3 //#include "features/CharFeatures.h"
4 //#include "features/StringFeatures.h"
12 #include "fill_matrix.h"
13 #include "penalty_info.h"
// Initializes (resets) a penalty_struct in place before it is populated.
// Default set here: transform = T_LINEAR. T_LINEAR is what an empty transform
// string maps to in the (disabled) MATLAB reader below, and it is presumably
// the identity transform (the other transforms apply log / +1 / +3) -- TODO
// confirm.
// NOTE(review): id, next_pen, cache, use_svm, min_len, max_len, limits,
// penalties and name are all read elsewhere in this file; confirm each one is
// given a default here too.
15 void init_penalty_struct(struct penalty_struct
&PEN
)
21 PEN
.transform
= T_LINEAR
;
// Precomputes this struct's own penalty for every integer value 0..PEN.max_len.
// The results go into a newly allocated array of PEN.max_len+1 REALs, so that
// lookup_penalty can answer by indexing instead of transforming and
// interpolating on every call.
//  - Entries for values below PEN.min_len come out as -CMath::INFTY, because
//    lookup_penalty rejects values outside [min_len, max_len].
//  - Values are computed with follow_next=false. The cache therefore holds only
//    this struct's contribution; the next_pen chain is added at lookup time.
//  - svm_values is passed as 0 (null). That is only safe for non-SVM structs,
//    which is presumably why the PEN.cache / PEN.use_svm check below makes the
//    function bail out early -- TODO confirm.
//  - The values are written through a local pointer rather than PEN.cache.
//    So, assuming PEN.cache is still NULL here, the lookup_penalty calls
//    compute each value directly instead of reading a half-built cache.
29 void init_penalty_struct_cache(struct penalty_struct
&PEN
)
31 if (PEN
.cache
|| PEN
.use_svm
)
34 REAL
* cache
=new REAL
[PEN
.max_len
+1] ;
37 for (INT i
=0; i
<=PEN
.max_len
; i
++)
38 cache
[i
] = lookup_penalty(&PEN
, i
, 0, false) ;
// Releases the heap arrays owned by a single penalty_struct. PEN.penalties is
// freed with delete[], matching its new[] allocation. The struct object itself
// is passed by reference and is not freed here.
// NOTE(review): other members are also new[]-allocated in this file:
// PEN.cache (init_penalty_struct_cache), and PEN.limits / PEN.name in the
// disabled MATLAB reader. Confirm they are released here as well.
43 void delete_penalty_struct(struct penalty_struct
&PEN
)
48 delete[] PEN
.penalties
;
// Releases the owned members of the first `len` elements of PEN by calling
// delete_penalty_struct on each of them.
// NOTE(review): confirm whether the array PEN itself is also delete[]'d here.
// The disabled MATLAB reader allocates it with new[] and, on its error paths,
// calls only this function without freeing PEN separately.
54 void delete_penalty_struct_array(struct penalty_struct
*PEN
, INT len
)
56 for (int i
=0; i
<len
; i
++)
57 delete_penalty_struct(PEN
[i
]) ;
62 //struct penalty_struct * read_penalty_struct_from_cell(const mxArray * mx_penalty_info, int &P)
64 // P = mxGetN(mx_penalty_info) ;
66 // struct penalty_struct * PEN = new struct penalty_struct[P] ;
67 // for (INT i=0; i<P; i++)
68 // init_penalty_struct(PEN[i]) ;
70 // for (INT i=0; i<P; i++)
72 // const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
73 // if (mx_elem==NULL || !mxIsStruct(mx_elem))
75 // CIO::message(M_ERROR, "empty cell element\n") ;
76 // delete_penalty_struct_array(PEN,P) ;
79 // const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
80 // if (mx_id_field==NULL || !mxIsNumeric(mx_id_field) ||
81 // mxGetN(mx_id_field)!=1 || mxGetM(mx_id_field)!=1)
83 // CIO::message(M_ERROR, "missing id field\n") ;
84 // delete_penalty_struct_array(PEN,P) ;
87 // const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
88 // if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
89 // mxGetM(mx_limits_field)!=1)
91 // CIO::message(M_ERROR, "missing limits field\n") ;
92 // delete_penalty_struct_array(PEN,P) ;
95 // INT len = mxGetN(mx_limits_field) ;
97 // const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
98 // if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
99 // mxGetM(mx_penalties_field)!=1 || mxGetN(mx_penalties_field)!=len)
101 // CIO::message(M_ERROR, "missing penalties field\n") ;
102 // delete_penalty_struct_array(PEN,P) ;
105 // const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
106 // if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
108 // CIO::message(M_ERROR, "missing transform field\n") ;
109 // delete_penalty_struct_array(PEN,P) ;
112 // const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
113 // if (mx_name_field==NULL || !mxIsChar(mx_name_field))
115 // CIO::message(M_ERROR, "missing name field\n") ;
116 // delete_penalty_struct_array(PEN,P) ;
119 // const mxArray* mx_max_len_field = mxGetField(mx_elem, 0, "max_len") ;
120 // if (mx_max_len_field==NULL || !mxIsNumeric(mx_max_len_field) ||
121 // mxGetM(mx_max_len_field)!=1 || mxGetN(mx_max_len_field)!=1)
123 // CIO::message(M_ERROR, "missing max_len field\n") ;
124 // delete_penalty_struct_array(PEN,P) ;
127 // const mxArray* mx_min_len_field = mxGetField(mx_elem, 0, "min_len") ;
128 // if (mx_min_len_field==NULL || !mxIsNumeric(mx_min_len_field) ||
129 // mxGetM(mx_min_len_field)!=1 || mxGetN(mx_min_len_field)!=1)
131 // CIO::message(M_ERROR, "missing min_len field\n") ;
132 // delete_penalty_struct_array(PEN,P) ;
135 // const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
136 // if (mx_use_svm_field==NULL || !mxIsNumeric(mx_use_svm_field) ||
137 // mxGetM(mx_use_svm_field)!=1 || mxGetN(mx_use_svm_field)!=1)
139 // CIO::message(M_ERROR, "missing use_svm field\n") ;
140 // delete_penalty_struct_array(PEN,P) ;
143 // INT use_svm = (INT) mxGetScalar(mx_use_svm_field) ;
144 // //fprintf(stderr, "use_svm_field=%i\n", use_svm) ;
146 // const mxArray* mx_next_id_field = mxGetField(mx_elem, 0, "next_id") ;
147 // if (mx_next_id_field==NULL || !mxIsNumeric(mx_next_id_field) ||
148 // mxGetM(mx_next_id_field)!=1 || mxGetN(mx_next_id_field)!=1)
150 // CIO::message(M_ERROR, "missing next_id field\n") ;
151 // delete_penalty_struct_array(PEN,P) ;
154 // INT next_id = (INT) mxGetScalar(mx_next_id_field)-1 ;
156 // INT id = (INT) mxGetScalar(mx_id_field)-1 ;
159 // CIO::message(M_ERROR, "id out of range\n") ;
160 // delete_penalty_struct_array(PEN,P) ;
163 // INT max_len = (INT) mxGetScalar(mx_max_len_field) ;
164 // if (max_len<0 || max_len>1024*1024*100)
166 // CIO::message(M_ERROR, "max_len out of range\n") ;
167 // delete_penalty_struct_array(PEN,P) ;
170 // PEN[id].max_len = max_len ;
172 // INT min_len = (INT) mxGetScalar(mx_min_len_field) ;
173 // if (min_len<0 || min_len>1024*1024*100)
175 // CIO::message(M_ERROR, "min_len out of range\n") ;
176 // delete_penalty_struct_array(PEN,P) ;
179 // PEN[id].min_len = min_len ;
181 // if (PEN[id].id!=-1)
183 // CIO::message(M_ERROR, "penalty id already used\n") ;
184 // delete_penalty_struct_array(PEN,P) ;
189 // PEN[id].next_pen=&PEN[next_id] ;
190 // //fprintf(stderr,"id=%i, next_id=%i\n", id, next_id) ;
192 // assert(next_id!=id) ;
193 // PEN[id].use_svm=use_svm ;
194 // PEN[id].limits = new REAL[len] ;
195 // PEN[id].penalties = new REAL[len] ;
196 // double * limits = mxGetPr(mx_limits_field) ;
197 // double * penalties = mxGetPr(mx_penalties_field) ;
199 // for (INT i=0; i<len; i++)
201 // PEN[id].limits[i]=limits[i] ;
202 // PEN[id].penalties[i]=penalties[i] ;
204 // PEN[id].len = len ;
206 // char *transform_str = mxArrayToString(mx_transform_field) ;
207 // char *name_str = mxArrayToString(mx_name_field) ;
209 // if (strcmp(transform_str, "log")==0)
210 // PEN[id].transform = T_LOG ;
211 // else if (strcmp(transform_str, "log(+1)")==0)
212 // PEN[id].transform = T_LOG_PLUS1 ;
213 // else if (strcmp(transform_str, "log(+3)")==0)
214 // PEN[id].transform = T_LOG_PLUS3 ;
215 // else if (strcmp(transform_str, "(+3)")==0)
216 // PEN[id].transform = T_LINEAR_PLUS3 ;
217 // else if (strcmp(transform_str, "")==0)
218 // PEN[id].transform = T_LINEAR ;
221 // delete_penalty_struct_array(PEN,P) ;
222 // mxFree(transform_str) ;
225 // PEN[id].name = new char[strlen(name_str)+1] ;
226 // strcpy(PEN[id].name, name_str) ;
228 // init_penalty_struct_cache(PEN[id]) ;
230 // mxFree(transform_str) ;
231 // mxFree(name_str) ;
// Penalty for an SVM-score-based penalty_struct (requires PEN->use_svm > 0).
//
// The input value is d_values[PEN->use_svm - 1], so use_svm is a 1-based index
// into d_values.
// The value is first transformed according to PEN->transform:
// log(x), log(x+1), log(x+3) or x+3. An unrecognized transform is reported
// through CIO::message(M_ERROR, ...).
// The transformed value is then mapped to a penalty by piecewise-linear
// interpolation over the PEN->len points (PEN->limits[i], PEN->penalties[i]):
//   - value below every limit  -> penalties[0]
//   - value at/above all limits -> penalties[len-1]
//   - otherwise                 -> linear interpolation between the two limits
//                                  that bracket the value.
// The penalty of the chained PEN->next_pen is then added via lookup_penalty.
// That struct may itself be value-based or SVM-based.
//
// p_value is not used for this struct's own penalty; it is only forwarded to
// the next_pen lookup.
// Assumes PEN->limits is sorted ascending -- TODO confirm at the loader.
// NOTE(review): under a log transform, a non-positive input gives -inf or NaN.
// Neither compares >= any finite limit, so it lands in the penalties[0] case.
236 REAL
lookup_penalty_svm(const struct penalty_struct
*PEN
, INT p_value
, REAL
*d_values
)
240 assert(PEN
->use_svm
>0) ;
241 REAL d_value
=d_values
[PEN
->use_svm
-1] ;
242 //fprintf(stderr,"transform=%i, d_value=%1.2f\n", (INT)PEN->transform, d_value) ;
244 switch (PEN
->transform
)
249 d_value
= log(d_value
) ;
252 d_value
= log(d_value
+1) ;
255 d_value
= log(d_value
+3) ;
258 d_value
= d_value
+3 ;
261 CIO::message(M_ERROR
, "unknown transform\n") ;
// idx appears to count the limits <= d_value, i.e. the position of d_value
// within the (sorted) limits table.
267 for (INT i
=0; i
<PEN
->len
; i
++)
268 if (PEN
->limits
[i
]<=d_value
)
272 ret
=PEN
->penalties
[0] ;
273 else if (idx
==PEN
->len
)
274 ret
=PEN
->penalties
[PEN
->len
-1] ;
// Linear interpolation between (limits[idx-1], penalties[idx-1]) and
// (limits[idx], penalties[idx]).
277 ret
= (PEN
->penalties
[idx
]*(d_value
-PEN
->limits
[idx
-1]) + PEN
->penalties
[idx
-1]*
278 (PEN
->limits
[idx
]-d_value
)) / (PEN
->limits
[idx
]-PEN
->limits
[idx
-1]) ;
281 //fprintf(stderr,"ret=%1.2f\n", ret) ;
284 ret
+=lookup_penalty(PEN
->next_pen
, p_value
, d_values
);
286 //fprintf(stderr,"ret=%1.2f\n", ret) ;
// Returns the penalty of integer value p_value under PEN. When follow_next is
// true, the penalties of the chained PEN->next_pen structs are added.
//
//  - SVM-based structs are delegated to lookup_penalty_svm(PEN, p_value,
//    svm_values). That delegation is presumably guarded by PEN->use_svm --
//    TODO confirm.
//    NOTE(review): follow_next is not forwarded there, so the SVM path appears
//    to always add its next_pen chain.
//  - A p_value outside [PEN->min_len, PEN->max_len] yields -CMath::INFTY
//    immediately. The next_pen chain is not consulted in that case.
//  - If PEN->cache has been built (see init_penalty_struct_cache), the cached
//    value for p_value is used.
//  - Otherwise (REAL)p_value is transformed per PEN->transform and mapped by
//    piecewise-linear interpolation over (PEN->limits, PEN->penalties),
//    clamped to the first/last penalty outside the table.
//
// NOTE(review): the transform/interpolation code duplicates lookup_penalty_svm;
// a shared helper would keep the two implementations in sync.
291 REAL
lookup_penalty(const struct penalty_struct
*PEN
, INT p_value
,
292 REAL
* svm_values
, bool follow_next
)
297 return lookup_penalty_svm(PEN
, p_value
, svm_values
) ;
299 if ((p_value
<PEN
->min_len
) || (p_value
>PEN
->max_len
))
300 return -CMath::INFTY
;
// After the range check above, p_value <= max_len already holds. The extra
// bounds also keep the cache index within [0, max_len] if min_len is negative.
302 if (PEN
->cache
!=NULL
&& (p_value
>=0) && (p_value
<=PEN
->max_len
))
304 REAL ret
=PEN
->cache
[p_value
] ;
305 if (PEN
->next_pen
&& follow_next
)
306 ret
+=lookup_penalty(PEN
->next_pen
, p_value
, svm_values
);
// Uncached path: transform the value itself, then interpolate over the table.
310 REAL d_value
= (REAL
) p_value
;
311 switch (PEN
->transform
)
316 d_value
= log(d_value
) ;
319 d_value
= log(d_value
+1) ;
322 d_value
= log(d_value
+3) ;
325 d_value
= d_value
+3 ;
328 CIO::message(M_ERROR
, "unknown transform\n") ;
// idx appears to count the limits <= d_value (position within sorted limits).
334 for (INT i
=0; i
<PEN
->len
; i
++)
335 if (PEN
->limits
[i
]<=d_value
)
339 ret
=PEN
->penalties
[0] ;
340 else if (idx
==PEN
->len
)
341 ret
=PEN
->penalties
[PEN
->len
-1] ;
// Linear interpolation between (limits[idx-1], penalties[idx-1]) and
// (limits[idx], penalties[idx]).
344 ret
= (PEN
->penalties
[idx
]*(d_value
-PEN
->limits
[idx
-1]) + PEN
->penalties
[idx
-1]*
345 (PEN
->limits
[idx
]-d_value
)) / (PEN
->limits
[idx
]-PEN
->limits
[idx
-1]) ;
347 //if (p_value>=30 && p_value<150)
348 //fprintf(stderr, "%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ;
350 if (PEN
->next_pen
&& follow_next
)
351 ret
+=lookup_penalty(PEN
->next_pen
, p_value
, svm_values
);