Adopted contributions by Christian.
[libdai.git] / index.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __INDEX_H__
23 #define __INDEX_H__
24
25
#include <cassert>
#include <cstddef>
#include <vector>
#include "varset.h"
28
29
30 namespace dai {
31
32
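/* Index: iterate over all joint states of a VarSet while keeping track of the
 * linear index of the corresponding joint state of a subset of that VarSet.
 *
 * Example:
 *
 * Index i ({s_j_1,s_j_2,...,s_j_m}, {s_1,...,s_N}); // j_k in {1,...,N}
 * for( ; i>=0; ++i ) {
 *     // loops over all states of (s_1,...,s_N)
 *     // i is linear index of corresponding state of (s_j_1, ..., s_j_m)
 * }
 *
 * The first variable (in VarSet iteration order) of each set varies fastest.
 * After the last joint state has been visited, i becomes -1, which ends the loop.
 */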
41
42
43 class Index
44 {
45 private:
46 long _index;
47 std::vector<int> _count,_max,_sum;
48 public:
49 Index () { _index=-1; };
50 Index (const VarSet& P, const VarSet& ns)
51 {
52 long sum=1;
53 VarSet::const_iterator j=ns.begin();
54 for(VarSet::const_iterator i=P.begin();i!=P.end();++i)
55 {
56 for(;j!=ns.end()&&j->label()<=i->label();++j)
57 {
58 _count.push_back(0);
59 _max.push_back(j->states());
60 _sum.push_back((i->label()==j->label())?sum:0);
61 };
62 sum*=i->states();
63 };
64 for(;j!=ns.end();++j)
65 {
66 _count.push_back(0);
67 _max.push_back(j->states());
68 _sum.push_back(0);
69 };
70 _index=0;
71 };
72 Index (const Index & ind) : _index(ind._index), _count(ind._count), _max(ind._max), _sum(ind._sum) {};
73 Index & operator=(const Index & ind) {
74 if(this!=&ind) {
75 _index = ind._index;
76 _count = ind._count;
77 _max = ind._max;
78 _sum = ind._sum;
79 }
80 return *this;
81 }
82 Index& clear ()
83 {
84 for(unsigned i=0;i!=_count.size();++i) _count[i]=0;
85 _index=0;
86 return(*this);
87 };
88 operator long () const { return(_index); };
89 Index& operator ++ ()
90 {
91 if(_index>=0)
92 {
93 unsigned i;
94 for(i=0;(i<_count.size())
95 &&(_index+=_sum[i],++_count[i]==_max[i]);++i)
96 {
97 _index-=_sum[i]*_max[i];
98 _count[i]=0;
99 };
100 if(i==_count.size()) _index=-1;
101 };
102 return(*this);
103 };
104 };
105
106
107 class multind {
108 private:
109 std::vector<size_t> _dims; // dimensions
110 std::vector<size_t> _pdims; // products of dimensions
111
112 public:
113 multind(const std::vector<size_t> di) {
114 _dims = di;
115 size_t prod = 1;
116 for( std::vector<size_t>::const_iterator i=di.begin(); i!=di.end(); i++ ) {
117 _pdims.push_back(prod);
118 prod = prod * (*i);
119 }
120 _pdims.push_back(prod);
121 }
122 multind(const VarSet& ns) {
123 _dims.reserve( ns.size() );
124 _pdims.reserve( ns.size() + 1 );
125 size_t prod = 1;
126 for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
127 _pdims.push_back( prod );
128 prod *= n->states();
129 _dims.push_back( n->states() );
130 }
131 _pdims.push_back( prod );
132 }
133 std::vector<size_t> vi(size_t li) const { // linear index to vector index
134 std::vector<size_t> v(_dims.size(),0);
135 assert(li < _pdims.back());
136 for( long j = v.size()-1; j >= 0; j-- ) {
137 size_t q = li / _pdims[j];
138 v[j] = q;
139 li = li - q * _pdims[j];
140 }
141 return v;
142 }
143 size_t li(const std::vector<size_t> vi) const { // linear index
144 size_t s = 0;
145 assert(vi.size() == _dims.size());
146 for( size_t j = 0; j < vi.size(); j++ )
147 s += vi[j] * _pdims[j];
148 return s;
149 }
150 size_t max() const { return( _pdims.back() ); };
151
152 // FIXME add an iterator, which increases a vector index just using addition
153 };
154
155
156 }
157
158
159 #endif