Merge branch 'pletscher'
[libdai.git] / src / evidence.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Charles Vaske [cvaske at soe dot ucsc dot edu]
8 * Copyright (C) 2009 University of California, Santa Cruz
9 */
10
11
12 #include <sstream>
13 #include <string>
14 #include <cstdlib>
15
16 #include <dai/util.h>
17 #include <dai/evidence.h>
18
19
20 namespace dai {
21
22
23 void Observation::addObservation( Var node, size_t setting ) {
24 _obs[node] = setting;
25 }
26
27
28 void Observation::applyEvidence( InfAlg &alg ) const {
29 for( std::map<Var, size_t>::const_iterator i = _obs.begin(); i != _obs.end(); ++i )
30 alg.clamp( alg.fg().findVar(i->first), i->second );
31 }
32
33
34 void Evidence::addEvidenceTabFile( std::istream &is, FactorGraph &fg ) {
35 std::map<std::string, Var> varMap;
36 for( std::vector<Var>::const_iterator v = fg.vars().begin(); v != fg.vars().end(); ++v ) {
37 std::stringstream s;
38 s << v->label();
39 varMap[s.str()] = *v;
40 }
41
42 addEvidenceTabFile( is, varMap );
43 }
44
45
46 void Evidence::addEvidenceTabFile( std::istream &is, std::map<std::string, Var> &varMap ) {
47 std::string line;
48 getline( is, line );
49
50 // Parse header
51 std::vector<std::string> header_fields;
52 tokenizeString( line, header_fields );
53 std::vector<std::string>::const_iterator p_field = header_fields.begin();
54 if( p_field == header_fields.end() )
55 DAI_THROW(INVALID_EVIDENCE_FILE);
56
57 std::vector<Var> vars;
58 for( ; p_field != header_fields.end(); ++p_field ) {
59 std::map<std::string, Var>::iterator elem = varMap.find( *p_field );
60 if( elem == varMap.end() )
61 DAI_THROW(INVALID_EVIDENCE_FILE);
62 vars.push_back( elem->second );
63 }
64
65 // Read samples
66 while( getline(is, line) ) {
67 std::vector<std::string> fields;
68 tokenizeString( line, fields );
69 if( fields.size() != vars.size() )
70 DAI_THROW(INVALID_EVIDENCE_FILE);
71
72 Observation sampleData;
73 for( size_t i = 0; i < vars.size(); ++i ) {
74 if( fields[i].size() > 0 ) { // skip if missing observation
75 if( fields[i].find_first_not_of("0123456789") != std::string::npos )
76 DAI_THROW(INVALID_EVIDENCE_FILE);
77 size_t state = atoi( fields[i].c_str() );
78 if( state >= vars[i].states() )
79 DAI_THROW(INVALID_EVIDENCE_FILE);
80 sampleData.addObservation( vars[i], state );
81 }
82 }
83 _samples.push_back( sampleData );
84 } // finished sample line
85 }
86
87
88 } // end of namespace dai