edafbdb9079fb20bc8b582971a7569e16a6f9c44
[libdai.git] / src / evidence.cpp
1 #include <sstream>
2 #include <string>
3 #include <cstdlib>
4
5 #include <dai/util.h>
6
7 #include <dai/evidence.h>
8
9 namespace dai {
10
11 void SampleData::addObservation(Var node, size_t setting) {
12 _obs[node] = setting;
13 }
14
15 void SampleData::applyEvidence(InfAlg& alg) const {
16 std::map< Var, size_t>::const_iterator i = _obs.begin();
17 for( ; i != _obs.end(); ++i) {
18 alg.clamp(i->first, i->second);
19 }
20 }
21
22 void Evidence::addEvidenceTabFile(std::istream& is, FactorGraph& fg) {
23 std::map< std::string, Var > varMap;
24 std::vector< Var >::const_iterator v = fg.vars().begin();
25 for(; v != fg.vars().end(); ++v) {
26 std::stringstream s;
27 s << v->label();
28 varMap[s.str()] = *v;
29 }
30
31 addEvidenceTabFile(is, varMap);
32 }
33
34 void Evidence::addEvidenceTabFile(std::istream& is,
35 std::map< std::string, Var >& varMap) {
36
37 std::vector< std::string > header_fields;
38 std::vector< Var > vars;
39 std::string line;
40 getline(is, line);
41
42 // Parse header
43 tokenizeString(line, header_fields);
44 std::vector< std::string >::const_iterator p_field = header_fields.begin();
45
46 if (p_field == header_fields.end()) { DAI_THROW(INVALID_EVIDENCE_LINE); }
47
48 ++p_field; // first column are sample labels
49 for ( ; p_field != header_fields.end(); ++p_field) {
50 std::map< std::string, Var >::iterator elem = varMap.find(*p_field);
51 if (elem == varMap.end()) {
52 DAI_THROW(INVALID_EVIDENCE_FILE);
53 }
54 vars.push_back(elem->second);
55 }
56
57
58 // Read samples
59 while(getline(is, line)) {
60 std::vector< std::string > fields;
61
62 tokenizeString(line, fields);
63
64 if (fields.size() != vars.size() + 1) { DAI_THROW(INVALID_EVIDENCE_LINE); }
65
66 SampleData& sampleData = _samples[fields[0]];
67 sampleData.name(fields[0]); // in case of a new sample
68 for (size_t i = 0; i < vars.size(); ++i) {
69 if (fields[i+1].size() > 0) { // skip if missing observation
70 if (fields[i+1].find_first_not_of("0123456789") != std::string::npos) {
71 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
72 }
73 size_t state = atoi(fields[i+1].c_str());
74 if (state >= vars[i].states()) {
75 DAI_THROW(INVALID_EVIDENCE_OBSERVATION);
76 }
77 sampleData.addObservation(vars[i], state);
78 }
79 }
80 } // finished sample line
81 }
82
83 }