blob: 7bfd3b8166de0ffe47eaaa6190bbd9ea4702b8a1 [file] [log] [blame]
Wei Hua6b4eebc2012-03-09 10:24:16 -08001/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
// Purpose: A container for sparse weight vectors.
// Maintains the sparse vector as a map of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value/normalizer_) is the
// true value in question.
21
22#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
23#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
24
25#include <hash_map>
26#include <iosfwd>
27#include <math.h>
28#include <sstream>
29#include <string>
30
31#include "common_defs.h"
32
33namespace learning_stochastic_linear {
34
35template<class Key = std::string, class Hash = std::hash_map<Key, double> >
36class SparseWeightVector {
37 public:
38 typedef Hash Wmap;
39 typedef typename Wmap::iterator Witer;
40 typedef typename Wmap::const_iterator Witer_const;
41 SparseWeightVector() {
42 normalizer_ = 1.0;
43 }
44 ~SparseWeightVector() {}
45 explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
46 CopyFrom(other);
47 }
48 void operator=(const SparseWeightVector<Key, Hash> &other) {
49 CopyFrom(other);
50 }
51 void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
52 w_ = other.w_;
53 wmin_ = other.wmin_;
54 wmax_ = other.wmax_;
55 normalizer_ = other.normalizer_;
56 }
57
58 // This function implements checks to prevent unbounded vectors. It returns
59 // true if the checks succeed and false otherwise. A vector is deemed invalid
60 // if any of these conditions are met:
61 // 1. it has no values.
62 // 2. its normalizer is nan or inf or close to zero.
63 // 3. any of its values are nan or inf.
64 // 4. its L0 norm is close to zero.
65 bool IsValid() const;
66
67 // Normalizer getters and setters.
68 double GetNormalizer() const {
69 return normalizer_;
70 }
71 void SetNormalizer(const double norm) {
72 normalizer_ = norm;
73 }
74 void NormalizerMultUpdate(const double mul) {
75 normalizer_ = normalizer_ * mul;
76 }
77 void NormalizerAddUpdate(const double add) {
78 normalizer_ += add;
79 }
80
81 // Divides all the values by the normalizer, then it resets it to 1.0
82 void ResetNormalizer();
83
84 // Bound getters and setters.
85 // True if there is a bound with val containing the bound. false otherwise.
86 bool GetElementMinBound(const Key &fname, double *val) const {
87 return GetValue(wmin_, fname, val);
88 }
89 bool GetElementMaxBound(const Key &fname, double *val) const {
90 return GetValue(wmax_, fname, val);
91 }
92 void SetElementMinBound(const Key &fname, const double bound) {
93 wmin_[fname] = bound;
94 }
95 void SetElementMaxBound(const Key &fname, const double bound) {
96 wmax_[fname] = bound;
97 }
98 // Element getters and setters.
99 double GetElement(const Key &fname) const {
100 double val = 0;
101 GetValue(w_, fname, &val);
102 return val;
103 }
104 void SetElement(const Key &fname, const double val) {
105 //DCHECK(!isnan(val));
106 w_[fname] = val;
107 }
108 void AddUpdateElement(const Key &fname, const double val) {
109 w_[fname] += val;
110 }
111 void MultUpdateElement(const Key &fname, const double val) {
112 w_[fname] *= val;
113 }
114 // Load another weight vectors. Will overwrite the current vector.
115 void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
116 w_.clear();
117 w_.insert(vec.w_.begin(), vec.w_.end());
118 wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
119 wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
120 normalizer_ = vec.normalizer_;
121 }
122 void Clear() {
123 w_.clear();
124 wmax_.clear();
125 wmin_.clear();
126 }
127 const Wmap& GetMap() const {
128 return w_;
129 }
130 // Vector Operations.
131 void AdditiveWeightUpdate(const double multiplier,
132 const SparseWeightVector<Key, Hash> &w1,
133 const double additive_const);
134 void AdditiveSquaredWeightUpdate(const double multiplier,
135 const SparseWeightVector<Key, Hash> &w1,
136 const double additive_const);
137 void AdditiveInvSqrtWeightUpdate(const double multiplier,
138 const SparseWeightVector<Key, Hash> &w1,
139 const double additive_const);
140 void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
141 double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
142 // L-x norm. eg. L1, L2.
143 double LxNorm(const double x) const;
144 double L2Norm() const;
145 double L1Norm() const;
146 double L0Norm(const double epsilon) const;
147 // Bound preserving updates.
148 void AdditiveWeightUpdateBounded(const double multiplier,
149 const SparseWeightVector<Key, Hash> &w1,
150 const double additive_const);
151 void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
152 void ReprojectToBounds();
153 void ReprojectL0(const double l0_norm);
154 void ReprojectL1(const double l1_norm);
155 void ReprojectL2(const double l2_norm);
156 // Reproject using the given norm.
157 // Will also rescale regularizer_ if it gets too small/large.
158 int32 Reproject(const double norm, const RegularizationType r);
159 // Convert this vector to a string, simply for debugging.
160 std::string DebugString() const {
161 std::stringstream stream;
162 stream << *this;
163 return stream.str();
164 }
165 private:
166 // The weight map.
167 Wmap w_;
168 // Constraint bounds.
169 Wmap wmin_;
170 Wmap wmax_;
171 // Normalizing constant in magnitude measurement.
172 double normalizer_;
173 // This function in necessary since by default hash_map inserts an element
174 // if it does not find the key through [] operator. It implements a lookup
175 // without the space overhead of an add.
176 bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
177 Witer_const iter = w1.find(fname);
178 if (iter != w1.end()) {
179 (*val) = iter->second;
180 return true;
181 } else {
182 (*val) = 0;
183 return false;
184 }
185 }
186};
187
188// Outputs a SparseWeightVector, for debugging.
189template <class Key, class Hash>
190std::ostream& operator<<(std::ostream &stream,
191 const SparseWeightVector<Key, Hash> &vector) {
192 typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
193 stream << "[[ ";
194 for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
195 iter != w_map.end();
196 ++iter) {
197 stream << "<" << iter->first << ", " << iter->second << "> ";
198 }
199 return stream << " ]]";
200};
201
202} // namespace learning_stochastic_linear
203#endif // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_