Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / swig / example.py
# This file is part of libDAI - http://www.libdai.org/
2 #
3 # Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 #
5 # Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6
7
# This example program illustrates how to read a factor graph from
# a file and run Belief Propagation, Max-Product, decimation (DecMAP)
# and JunctionTree on it.
10 # This version uses the SWIG python wrapper of libDAI
11
12
from __future__ import print_function

import sys

import dai
15
# Driver: read a factor graph, run exact (JTree) and approximate (BP,
# max-product BP, DecMAP) inference on it, and print the results.


def make_property_set(settings, **extra):
    """Return a *new* dai.PropertySet holding ``settings`` plus ``extra``.

    A fresh object is built on every call on purpose: in Python,
    ``algopts = opts`` only binds another name to the same PropertySet, so
    per-algorithm settings (e.g. ``inference=MAXPROD``) would leak into every
    algorithm constructed afterwards.

    :param settings: mapping of property name -> string value.
    :param extra: additional or overriding properties (string values).
    :returns: a newly constructed ``dai.PropertySet``.
    """
    merged = dict(settings)
    merged.update(extra)
    props = dai.PropertySet()
    for key, value in merged.items():
        props[key] = value
    return props


def print_variable_marginals(title, alg, graph):
    """Print *title*, then the belief of *alg* for every variable of *graph*."""
    print(title)
    for i in range(graph.nrVars()):  # iterate over all variables in graph
        print(alg.belief(dai.VarSet(graph.var(i))))


def print_factor_marginals(title, alg, graph):
    """Print *title*, then the belief of *alg* over each factor's variables."""
    print(title)
    for factor_index in range(graph.nrFactors()):  # iterate over all factors
        print(alg.belief(graph.factor(factor_index).vars()))


def print_state(title, state, graph):
    """Print *title* with the log score of *state*, then one line per variable."""
    print(title, '(log score =', graph.logScore(state), '):')
    for i, value in enumerate(state):
        print(graph.var(i), ':', value)


if len(sys.argv) not in (2, 3):
    print('Usage:', sys.argv[0], "<filename.fg> [maxstates]")
    print('Reads factor graph <filename.fg> and runs')
    print('Belief Propagation, Max-Product and JunctionTree on it.')
    print('JunctionTree is only run if a junction tree is found with')
    print('total number of states less than <maxstates> (where 0 means unlimited).')
    sys.exit(1)

# Report inference algorithms built into libDAI.
# TODO: dai.builtinInfAlgNames() crashes through the SWIG wrapper, so this
# report is disabled:
# print('Builtin inference algorithms:', dai.builtinInfAlgNames())

# Read FactorGraph from the file specified by the first command line argument
fg = dai.FactorGraph()
fg.ReadFromFile(sys.argv[1])
maxstates = 1000000
if len(sys.argv) == 3:
    maxstates = int(sys.argv[2])

# Set some constants
maxiter = 10000
tol = 1e-9
verb = 1

# Properties shared by all algorithms (PropertySet values are passed as strings).
# Each algorithm gets its own PropertySet built from these via make_property_set.
base_opts = {
    "maxiter": str(maxiter),  # Maximum number of iterations
    "tol": str(tol),          # Tolerance for convergence
    "verbose": str(verb),     # Verbosity (amount of output generated)
}

# Bound treewidth for junctiontree
do_jt = True
# TODO: port the treewidth bound from the C++ example. Until then the junction
# tree is always run and maxstates is parsed but unused. C++ reference:
#   try {
#       boundTreewidth(fg, &eliminationCost_MinFill, maxstates );
#   } catch( Exception &e ) {
#       if( e.getCode() == Exception::OUT_OF_MEMORY ) {
#           do_jt = false;
#           cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl;
#       }
#       else
#           throw;
#   }

if do_jt:
    # Construct a JTree (junction tree) object from the FactorGraph fg using
    # the shared properties plus the type of updates JTree should perform
    jtopts = make_property_set(base_opts, updates="HUGIN")
    jt = dai.JTree(fg, jtopts)
    jt.init()  # Initialize junction tree algorithm
    jt.run()   # Run junction tree algorithm

    # Construct another JTree object that is used to calculate the joint
    # configuration of variables that has maximum probability (MAP state)
    jtmapopts = make_property_set(base_opts, updates="HUGIN", inference="MAXPROD")
    jtmap = dai.JTree(fg, jtmapopts)
    jtmap.init()
    jtmap.run()
    # Calculate joint state of all variables that has maximum probability
    jtmapstate = jtmap.findMaximum()

# Construct a BP (belief propagation) object: sequential random updates,
# computed in the real (not log) domain. Sum-product, since no "inference"
# property is set here.
bpopts = make_property_set(base_opts, updates="SEQRND", logdomain="0")
bp = dai.BP(fg, bpopts)
bp.init()  # Initialize belief propagation algorithm
bp.run()   # Run belief propagation algorithm

# Construct a second BP object with inference set to MAXPROD, i.e. it runs
# the max-product algorithm instead of sum-product (with damping 0.1)
mpopts = make_property_set(base_opts, updates="SEQRND", logdomain="0",
                           inference="MAXPROD", damping="0.1")
mp = dai.BP(fg, mpopts)
mp.init()  # Initialize max-product algorithm
mp.run()   # Run max-product algorithm
# Calculate joint state of all variables that has maximum probability
# based on the max-product result
mpstate = mp.findMaximum()

# Construct a decimation algorithm object that uses max-product BP internally
# and completely reinitializes its state at every step
decmapopts = make_property_set(
    base_opts,
    reinit="1",
    ianame="BP",
    iaopts="[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]",
)
decmap = dai.DecMAP(fg, decmapopts)
decmap.init()
decmap.run()
decmapstate = decmap.findMaximum()

# Report variable marginals (exact: junction tree; approximate: loopy BP)
if do_jt:
    print_variable_marginals('Exact variable marginals:', jt, fg)
print_variable_marginals('Approximate (loopy belief propagation) variable marginals:', bp, fg)

# Report factor marginals
if do_jt:
    print_factor_marginals('Exact factor marginals:', jt, fg)
print_factor_marginals('Approximate (loopy belief propagation) factor marginals:', bp, fg)

# Report log partition sum (normalizing constant) of fg
if do_jt:
    print('Exact log partition sum:', jt.logZ())
print('Approximate (loopy belief propagation) log partition sum:', bp.logZ())

# Report MAP variable marginals
if do_jt:
    print_variable_marginals('Exact MAP variable marginals:', jtmap, fg)
print_variable_marginals('Approximate (max-product) MAP variable marginals:', mp, fg)

# Report MAP factor marginals.
# NOTE(review): the C++ example prints belief(vars) == beliefF(I); the previous
# Python version printed belief(vars) twice, which is a tautology. If beliefF
# is exposed by the SWIG wrapper, consider restoring that comparison.
if do_jt:
    print_factor_marginals('Exact MAP factor marginals:', jtmap, fg)
print_factor_marginals('Approximate (max-product) MAP factor marginals:', mp, fg)

# Report MAP joint states with their log scores
if do_jt:
    print_state('Exact MAP state', jtmapstate, fg)
print_state('Approximate (max-product) MAP state', mpstate, fg)
print_state('Approximate DecMAP state', decmapstate, fg)
215 for i in range(len(decmapstate)):
216 print fg.var(i), ':', decmapstate[i]