# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
4"""Math utilities."""
5
6from __future__ import print_function
7
8import copy
9import math
10import sys
11
12
def _log(x):
  """Returns the natural logarithm of x, treating near-zero x as 0.0.

  Floating-point round-off can push a value that should be exactly zero
  slightly below zero, where math.log() would raise. Any x whose magnitude is
  below 1e-15 (tiny positive values included) is therefore mapped to 0.0.
  This also makes _xlogx(0) evaluate to 0, the limit of x*log(x) at zero.

  Args:
    x: The value to take the logarithm of.

  Returns:
    0.0 if abs(x) < 1e-15, otherwise math.log(x).

  Raises:
    ValueError: if x <= -1e-15 (raised by math.log).
  """
  near_zero = abs(x) < 1e-15
  return 0.0 if near_zero else math.log(x)
18 return math.log(x)
19
20
def _xlogx(x):
  """Returns x * log(x), the per-value term summed by EntropyStats.

  The logarithm comes from _log(), so any x with abs(x) < 1e-15 contributes
  exactly 0 instead of raising or producing NaN.
  """
  log_x = _log(x)
  return x * log_x
23
24
def least(values):
  """Returns the smallest strictly positive element of |values|.

  Args:
    values: An iterable of non-negative values.

  Returns:
    The minimum among the elements that are greater than zero.

  Raises:
    ValueError: if no element is greater than zero (including when |values|
      is empty).
  """
  positive = [v for v in values if v > 0]
  return min(positive)
38
39
class EntropyStats(object):
  r"""Maintains the entropy of a collection of weights incrementally.

  This is a math helper for NoisyBinarySearch. Given a set of probabilities
  (possibly unnormalized), this class computes the Shannon entropy of their
  normalized distribution and updates it incrementally as values change.

  Algorithm:
    Given unnormalized p_i, the entropy of the normalized distribution is
      Let S = \sum_i p_i
      Normalized(p): p_i' = p_i/S
      Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                 = \sum_i { -(p_i/S) (log p_i - log S) }
                 = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                 = -1/S * \sum_i { p_i*log p_i } + log S
                 = log S - 1/S * \sum_i { p_i*log p_i }

    So we maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }|, which allows
    incremental updates.

  Attributes:
    sum: The total S of all values in the collection.
    sum_log: The total of p_i * log(p_i) over all values in the collection.
  """

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable of probability values; they need not be normalized.
        It is consumed exactly once, so generators are accepted.
    """
    # Materialize first: both totals need a pass over the values, and a
    # one-shot iterator would otherwise be exhausted by the first sum(),
    # silently leaving sum_log at 0.
    values = list(p)
    self.sum = sum(values)
    self.sum_log = sum(_xlogx(v) for v in values)

  def entropy(self):
    """Returns the entropy of the normalized collection.

    Returns:
      log(S) - sum_log / S, or 0.0 when the values sum to zero (e.g. an empty
      or all-zero collection).
    """
    if self.sum == 0:
      return 0.0
    return _log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one value in the collection with another.

    The caller must ensure |old| is currently in the collection; this is not
    checked.

    Args:
      old: The value being removed.
      new: The value taking its place.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Returns a copy of this collection with every value scaled by |value|.

    This instance is left unchanged.

    Args:
      value: The scale factor. Expected to be non-negative; _xlogx() raises
        ValueError (via math.log) for clearly negative values.

    Returns:
      A new EntropyStats (a shallow copy of self) with updated totals.
    """
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    #   = \sum_i { (p_i*value) * (log p_i + log value) }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)
    other.sum *= value
    return other
100
101
class ExtendedFloat(object):
  """Custom floating point number with a larger exponent range.

  Problem:
    In the formula of the noisy binary search algorithm, we need to calculate
    p**n (where n is the test count) and normalize probabilities. When n is
    hundreds or more, p**n becomes too small and rounds to zero (underflow),
    so the normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is far less likely to underflow.

  Solution:
    A number is represented as
      x = mantissa * 2.**exponent
    using a python float for `mantissa` and a python int for `exponent`.

    For the mantissa, it is acceptable for calculated probabilities to be
    slightly inaccurate, because the worst case is just running a test one or
    two more times. So we keep the mantissa in python's built-in float.

    For the exponent, python's int has virtually unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Division and comparison are exposed through both the Python 2 hooks
  (__div__, __cmp__) and the Python 3 ones (__truediv__, rich comparisons).

  Attributes:
    mantissa: The mantissa part of the number. Its range is
      0.5 <= abs(mantissa) < 1.0, or exactly 0 for zero.
    exponent: The exponent part of the number (0 for zero).
  """

  def __init__(self, value, exp=0):
    """Creates the number |value| * 2**|exp|.

    Args:
      value: A number accepted by math.frexp().
      exp: Extra power of two to scale |value| by.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Keep zero canonical (exponent 0) regardless of |exp|.
    self.exponent = self.exponent + exp if value else 0

  @staticmethod
  def _coerce(value):
    """Returns |value| as an ExtendedFloat, converting plain numbers."""
    if isinstance(value, ExtendedFloat):
      return value
    return ExtendedFloat(value)

  def __float__(self):
    # Raises OverflowError for very large exponents; very small ones
    # underflow to 0.0.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    other = self._coerce(other)
    if self.mantissa == 0:
      return other
    if self.exponent < other.exponent:
      return other.__add__(self)
    # |other| has the smaller (or equal) exponent; rescale its mantissa to
    # this number's exponent so the two mantissas can be added as floats.
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Returns self**p.

    Args:
      p: An integer power, of any sign and magnitude.
    """
    # Because 0.5 <= abs(mantissa) < 1.0, mantissa**p neither underflows nor
    # overflows a float while abs(p) < -min_exp (1021 for IEEE doubles), so
    # the float power can be used directly. Checking abs(p) (rather than p)
    # keeps large negative powers from overflowing here.
    if abs(p) < -sys.float_info.min_exp:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    # Square-and-multiply. For negative p, p >> 1 rounds toward -inf and
    # p & 1 is still the low bit, so p == 2 * (p >> 1) + (p & 1) holds.
    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    other = self._coerce(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __div__(self, other):
    other = self._coerce(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rdiv__(self, other):
    return ExtendedFloat(other).__div__(self)

  # Python 3 (and `from __future__ import division`) dispatch `/` to these.
  __truediv__ = __div__
  __rtruediv__ = __rdiv__

  def __repr__(self):
    return "ExtendedFloat(%f, %d)" % (self.mantissa, self.exponent)

  def _compare(self, other):
    """Returns -1, 0 or 1 as self is less than, equal to or above |other|."""
    diff = self - self._coerce(other)
    return (diff.mantissa > 0) - (diff.mantissa < 0)

  def _rich_compare(self, other, predicate):
    """Applies |predicate| to _compare(); defers on non-numeric |other|."""
    try:
      other = self._coerce(other)
    except TypeError:
      return NotImplemented
    return predicate(self._compare(other))

  def __cmp__(self, other):
    # Python 2 comparison hook; Python 3 only uses the rich comparisons.
    return self._compare(other)

  def __eq__(self, other):
    return self._rich_compare(other, lambda c: c == 0)

  def __ne__(self, other):
    return self._rich_compare(other, lambda c: c != 0)

  def __lt__(self, other):
    return self._rich_compare(other, lambda c: c < 0)

  def __le__(self, other):
    return self._rich_compare(other, lambda c: c <= 0)

  def __gt__(self, other):
    return self._rich_compare(other, lambda c: c > 0)

  def __ge__(self, other):
    return self._rich_compare(other, lambda c: c >= 0)

  def __hash__(self):
    # Defining __eq__ would otherwise make instances unhashable on Python 3.
    # Equal ExtendedFloats share the same normalized (mantissa, exponent), and
    # one equal to a float x converts back to exactly x, so hashing the float
    # value keeps hash() consistent with ==.
    try:
      return hash(float(self))
    except OverflowError:
      return hash((self.mantissa, self.exponent))