Fixed testem failure caused by rounding error
[libdai.git] / src / evidence.cpp
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <sstream>
23 #include <string>
24 #include <cstdlib>
25
26 #include <dai/util.h>
27 #include <dai/evidence.h>
28
29
30 namespace dai {
31
32
33 void Observation::addObservation( Var node, size_t setting ) {
34 _obs[node] = setting;
35 }
36
37
38 void Observation::applyEvidence( InfAlg& alg ) const {
39 std::map<Var, size_t>::const_iterator i = _obs.begin();
40 for( ; i != _obs.end(); ++i )
41 alg.clamp( i->first, i->second );
42 }
43
44
45 void Evidence::addEvidenceTabFile( std::istream& is, FactorGraph& fg ) {
46 std::map<std::string, Var> varMap;
47 std::vector<Var>::const_iterator v = fg.vars().begin();
48 for( ; v != fg.vars().end(); ++v ) {
49 std::stringstream s;
50 s << v->label();
51 varMap[s.str()] = *v;
52 }
53
54 addEvidenceTabFile( is, varMap );
55 }
56
57
58 void Evidence::addEvidenceTabFile( std::istream& is, std::map<std::string, Var>& varMap ) {
59 std::vector<std::string> header_fields;
60 std::vector<Var> vars;
61 std::string line;
62 getline( is, line );
63
64 // Parse header
65 tokenizeString( line, header_fields );
66 std::vector<std::string>::const_iterator p_field = header_fields.begin();
67
68 if( p_field == header_fields.end() )
69 DAI_THROW(INVALID_EVIDENCE_LINE);
70
71 for( ; p_field != header_fields.end(); ++p_field ) {
72 std::map<std::string, Var>::iterator elem = varMap.find( *p_field );
73 if( elem == varMap.end() )
74 DAI_THROW(INVALID_EVIDENCE_FILE);
75 vars.push_back( elem->second );
76 }
77
78 // Read samples
79 while( getline(is, line) ) {
80 std::vector<std::string> fields;
81
82 tokenizeString( line, fields );
83
84 if( fields.size() != vars.size() )
85 DAI_THROW(INVALID_EVIDENCE_LINE);
86
87 Observation sampleData;
88 for( size_t i = 0; i < vars.size(); ++i ) {
89 if( fields[i].size() > 0 ) { // skip if missing observation
90 if( fields[i].find_first_not_of("0123456789") != std::string::npos )
91 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
92 size_t state = atoi( fields[i].c_str() );
93 if( state >= vars[i].states() )
94 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
95 sampleData.addObservation( vars[i], state );
96 }
97 }
98 _samples.push_back( sampleData );
99 } // finished sample line
100 }
101
102
103 } // end of namespace dai