Changed license from GPL v2+ to FreeBSD (aka BSD 2-clause) license
[libdai.git] / examples / example_sprinkler_em.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8 #include <dai/alldai.h>
9 #include <iostream>
10 #include <fstream>
11 #include <string>
12
13 using namespace std;
14 using namespace dai;
15
16 int main() {
17 // This example program illustrates how to learn the
18 // parameters of a Bayesian network from a sample of
19 // the sprinkler network discussed at
20 // http://www.cs.ubc.ca/~murphyk/Bayes/bnintro.html
21 //
22 // The factor graph file (sprinkler.fg) has to be generated first
23 // by running example_sprinkler, and the data sample file
24 // (sprinkler.tab) by running example_sprinkler_gibbs
25
26 // Read the factorgraph from the file
27 FactorGraph SprinklerNetwork;
28 SprinklerNetwork.ReadFromFile( "sprinkler.fg" );
29
30 // Prepare junction-tree object for doing exact inference for E-step
31 PropertySet infprops;
32 infprops.set( "verbose", (size_t)1 );
33 infprops.set( "updates", string("HUGIN") );
34 InfAlg* inf = newInfAlg( "JTREE", SprinklerNetwork, infprops );
35 inf->init();
36
37 // Read sample from file
38 Evidence e;
39 ifstream estream( "sprinkler.tab" );
40 e.addEvidenceTabFile( estream, SprinklerNetwork );
41 cout << "Number of samples: " << e.nrSamples() << endl;
42
43 // Read EM specification
44 ifstream emstream( "sprinkler.em" );
45 EMAlg em(e, *inf, emstream);
46
47 // Iterate EM until convergence
48 while( !em.hasSatisfiedTermConditions() ) {
49 Real l = em.iterate();
50 cout << "Iteration " << em.Iterations() << " likelihood: " << l <<endl;
51 }
52
53 // Output true factor graph
54 cout << endl << "True factor graph:" << endl << "##################" << endl;
55 cout.precision(12);
56 cout << SprinklerNetwork;
57
58 // Output learned factor graph
59 cout << endl << "Learned factor graph:" << endl << "#####################" << endl;
60 cout.precision(12);
61 cout << inf->fg();
62
63 // Clean up
64 delete inf;
65
66 return 0;
67 }