[Patrick Pletscher] Fixed performance issue in FactorGraph::clamp and FactorGraph...
[libdai.git] / src / evidence.cpp
1 /* Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
2 University of California Santa Cruz
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <sstream>
23 #include <string>
24 #include <cstdlib>
25
26 #include <dai/util.h>
27 #include <dai/evidence.h>
28
29
30 namespace dai {
31
32
33 void Observation::addObservation( Var node, size_t setting ) {
34 _obs[node] = setting;
35 }
36
37
38 void Observation::applyEvidence( InfAlg &alg ) const {
39 for( std::map<Var, size_t>::const_iterator i = _obs.begin(); i != _obs.end(); ++i )
40 alg.clamp( i->first, i->second );
41 }
42
43
44 void Evidence::addEvidenceTabFile( std::istream &is, FactorGraph &fg ) {
45 std::map<std::string, Var> varMap;
46 for( std::vector<Var>::const_iterator v = fg.vars().begin(); v != fg.vars().end(); ++v ) {
47 std::stringstream s;
48 s << v->label();
49 varMap[s.str()] = *v;
50 }
51
52 addEvidenceTabFile( is, varMap );
53 }
54
55
56 void Evidence::addEvidenceTabFile( std::istream &is, std::map<std::string, Var> &varMap ) {
57 std::string line;
58 getline( is, line );
59
60 // Parse header
61 std::vector<std::string> header_fields;
62 tokenizeString( line, header_fields );
63 std::vector<std::string>::const_iterator p_field = header_fields.begin();
64 if( p_field == header_fields.end() )
65 DAI_THROW(INVALID_EVIDENCE_LINE);
66
67 std::vector<Var> vars;
68 for( ; p_field != header_fields.end(); ++p_field ) {
69 std::map<std::string, Var>::iterator elem = varMap.find( *p_field );
70 if( elem == varMap.end() )
71 DAI_THROW(INVALID_EVIDENCE_FILE);
72 vars.push_back( elem->second );
73 }
74
75 // Read samples
76 while( getline(is, line) ) {
77 std::vector<std::string> fields;
78 tokenizeString( line, fields );
79 if( fields.size() != vars.size() )
80 DAI_THROW(INVALID_EVIDENCE_LINE);
81
82 Observation sampleData;
83 for( size_t i = 0; i < vars.size(); ++i ) {
84 if( fields[i].size() > 0 ) { // skip if missing observation
85 if( fields[i].find_first_not_of("0123456789") != std::string::npos )
86 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
87 size_t state = atoi( fields[i].c_str() );
88 if( state >= vars[i].states() )
89 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
90 sampleData.addObservation( vars[i], state );
91 }
92 }
93 _samples.push_back( sampleData );
94 } // finished sample line
95 }
96
97
98 } // end of namespace dai