Improved properties.h/cpp and added unit tests
[libdai.git] / examples / example_sprinkler_em.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10 #include <dai/alldai.h>
11 #include <iostream>
12 #include <fstream>
13 #include <string>
14
15 using namespace std;
16 using namespace dai;
17
18 int main() {
19 // This example program illustrates how to learn the
20 // parameters of a Bayesian network from a sample of
21 // the sprinkler network discussed at
22 // http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html
23 //
24 // The factor graph file (sprinkler.fg) has to be generated first
25 // by running example_sprinkler, and the data sample file
26 // (sprinkler.tab) by running example_sprinkler_gibbs
27
28 // Read the factorgraph from the file
29 FactorGraph SprinklerNetwork;
30 SprinklerNetwork.ReadFromFile( "sprinkler.fg" );
31
32 // Prepare junction-tree object for doing exact inference for E-step
33 PropertySet infprops;
34 infprops.set( "verbose", (size_t)1 );
35 infprops.set( "updates", string("HUGIN") );
36 InfAlg* inf = newInfAlg( "JTREE", SprinklerNetwork, infprops );
37 inf->init();
38
39 // Read sample from file
40 Evidence e;
41 ifstream estream( "sprinkler.tab" );
42 e.addEvidenceTabFile( estream, SprinklerNetwork );
43 cout << "Number of samples: " << e.nrSamples() << endl;
44
45 // Read EM specification
46 ifstream emstream( "sprinkler.em" );
47 EMAlg em(e, *inf, emstream);
48
49 // Iterate EM until convergence
50 while( !em.hasSatisfiedTermConditions() ) {
51 Real l = em.iterate();
52 cout << "Iteration " << em.Iterations() << " likelihood: " << l <<endl;
53 }
54
55 // Output true factor graph
56 cout << endl << "True factor graph:" << endl << "##################" << endl;
57 cout.precision(12);
58 cout << SprinklerNetwork;
59
60 // Output learned factor graph
61 cout << endl << "Learned factor graph:" << endl << "#####################" << endl;
62 cout.precision(12);
63 cout << inf->fg();
64
65 // Clean up
66 delete inf;
67
68 return 0;
69 }