blob: 6e7024d4d3f259fc3f6146fc595fa18bfb0d1e78 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
5"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
10import math
11import sys
12
Kuang-che Wua7ddf9b2019-11-25 18:59:57 +080013import six
14
Kuang-che Wu88875db2017-07-20 10:47:53 +080015
def log(x):
    """Natural logarithm that tolerates zero and tiny negative inputs.

    Accumulated arithmetic error may turn a value that should be zero into a
    tiny negative number. A tolerance of -1e-10 is not enough; see
    strategy_test.test_next_idx_arithmetic_error for detail.

    Args:
      x: Value to take the logarithm of.

    Returns:
      0.0 if -1e-9 < x <= 0, otherwise math.log(x).

    Raises:
      ValueError: if x <= -1e-9 (raised by math.log).
    """
    is_rounding_noise = -1e-9 < x <= 0
    return 0.0 if is_rounding_noise else math.log(x)
24
25
def _xlogx(x):
    """Returns x * log(x), using this module's zero-tolerant log()."""
    log_x = log(x)
    return x * log_x
Kuang-che Wu88875db2017-07-20 10:47:53 +080028
29
def least(values):
    """Finds the smallest strictly positive value.

    Args:
      values: Iterable of non-negative values.

    Returns:
      The minimum among the elements that are greater than zero.

    Raises:
      ValueError: if no element is greater than zero (all values are zero or
          negative, or `values` is empty).
    """
    positives = [candidate for candidate in values if candidate > 0]
    # min() raises ValueError by itself when nothing survived the filter.
    return min(positives)
43
44
def average(values):
    """Calculates the arithmetic mean of values.

    Args:
      values: Sized collection of numbers with at least one element.

    Returns:
      The mean as a float.

    Raises:
      ValueError: if `values` is empty.
    """
    if values:
        total = float(sum(values))
        return total / len(values)
    raise ValueError('calculate average of empty list')
57
58
class EntropyStats(object):
    r"""Incrementally maintained cross entropy of a set of values.

    This is a math helper for NoisyBinarySearch. Given a set of (possibly
    unnormalized) probabilities, this class keeps just enough state to update
    their cross entropy without revisiting every element.

    Algorithm:
      Given unnormalized p_i, let
        S = \sum_i p_i
        Normalized(p): p_i' = p_i/S
      Then
        CrossEntropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                        = \sum_i { -(p_i/S) (log p_i - log S) }
                        = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                        = -1/S * \sum_i { p_i*log p_i } + log S
                        = log S - 1/S * \sum_i { p_i*log p_i }

      So maintaining sum=|S| and sum_log=|\sum_i { p_i*log p_i }| is enough
      for incremental update.

    Attributes:
      sum: Sum of all values, i.e. S above.
      sum_log: Sum of p_i * log(p_i) over all values.
    """  # pylint: disable=docstring-trailing-quotes

    def __init__(self, p):
        """Initializes EntropyStats.

        Args:
          p: A list of probability values. Could be unnormalized.
        """
        self.sum = sum(p)
        self.sum_log = sum(_xlogx(p_i) for p_i in p)

    def entropy(self):
        """Returns cross entropy, or 0 if the values sum to zero."""
        total = self.sum
        if total == 0:
            return 0
        return log(total) - self.sum_log / total

    def replace(self, old, new):
        """Replaces one random variable in the collection with another.

        Args:
          old: original value to be replaced.
          new: new value
        """
        delta = new - old
        delta_log = _xlogx(new) - _xlogx(old)
        self.sum += delta
        self.sum_log += delta_log

    def multiply(self, value):
        """Multiplies all values in the collection by |value|.

        This instance is left untouched.

        Args:
          value: The factor applied to every value.

        Returns:
          A new instance of EntropyStats with calculated value.
        """
        scaled = copy.copy(self)
        # \sum_i { (p_i*value) * log (p_i*value) }
        #   = \sum_i { (p_i*value) * (log p_i + log value) }
        #   = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
        #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
        scaled.sum_log = value * self.sum_log + self.sum * _xlogx(value)

        scaled.sum *= value
        return scaled
119
120
class ExtendedFloat(object):
    """Custom floating point number with larger range of exponent.

    Problem:
      In the formula of noisy binary search algorithm, we need to calculate
      p**n (where n is test count) and normalize probabilities. When n is
      hundreds or more, p**n goes too small and becomes zero (underflow), and
      the normalization step will fail.

      Since n > 100 is not unreasonably large for noisy bisection, we need a
      numeric type that is less likely to underflow, or never does.

    Solution:
      Suppose a floating number is represented as
        x = mantissa * 2.**exponent
      We use python's float to represent `mantissa` and python's int to
      represent `exponent`.

      For mantissa, we can accept that the calculated probability is not super
      accurate because the worst case is just running the test one or two more
      times. So we keep mantissa in python's built-in float.

      For exponent, we represent it using python's int, which has virtually
      unlimited digits.

    Alternative solutions:
      1. Maybe we can rewrite the formula to avoid underflow. But I want to
         keep the arithmetic code simple.

      2. Use python stdlib's decimal or fractions. But they are very slow --
         decimal is 6-50x slower, fractions is 9-760x slower while
         ExtendedFloat is about 1.2-3.3x slower [1].

      [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                      hard N=1000, oracle=(0.45, 0.55)

    Attributes:
      mantissa: The mantissa part of floating number. Its range is 0.5 <=
          abs(mantissa) < 1.0 (as produced by math.frexp), or 0 for zero.
      exponent: The exponent part of floating number. Always 0 for zero.
    """

    def __init__(self, value, exp=0):
        """Creates the number value * 2**exp.

        Args:
          value: A number accepted by math.frexp.
          exp: Extra power-of-two exponent applied on top of `value`.
        """
        self.mantissa, self.exponent = math.frexp(value)
        # Zero is kept canonical (exponent 0) no matter what `exp` is.
        self.exponent = self.exponent + exp if value else 0

    def __float__(self):
        # math.ldexp underflows to 0.0 silently but raises OverflowError when
        # the value is too large for a built-in float.
        return math.ldexp(self.mantissa, self.exponent)

    def __neg__(self):
        return ExtendedFloat(-self.mantissa, self.exponent)

    def __abs__(self):
        return ExtendedFloat(abs(self.mantissa), self.exponent)

    def __add__(self, other):
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        if self.mantissa == 0:
            return other
        # Make sure `self` is the operand with the larger exponent so that the
        # shift below is never positive (i.e. it can only underflow).
        if self.exponent < other.exponent:
            return other.__add__(self)
        value = math.ldexp(other.mantissa, other.exponent - self.exponent)
        return ExtendedFloat(value + self.mantissa, self.exponent)

    __radd__ = __add__

    def __pow__(self, p):
        # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
        # underflow for small p.
        if p < -sys.float_info.min_exp:
            return ExtendedFloat(self.mantissa**p, self.exponent * p)

        # Larger p: exponentiation by squaring. The bit operations below
        # require p to be an integer here.
        half_pow = self.__pow__(p >> 1)
        result = half_pow.__mul__(half_pow)
        if p & 1:
            result = result.__mul__(self)
        return result

    def __sub__(self, other):
        return self.__add__(-other)

    def __rsub__(self, other):
        return ExtendedFloat(other).__add__(-self)

    def __mul__(self, other):
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa * other.mantissa,
                             self.exponent + other.exponent)

    __rmul__ = __mul__

    def __truediv__(self, other):
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa / other.mantissa,
                             self.exponent - other.exponent)

    def __rtruediv__(self, other):
        # Must call __truediv__: __div__ is only defined on python2, so using
        # it here broke `number / ExtendedFloat` on python3.
        return ExtendedFloat(other).__truediv__(self)

    # TODO(kcwu): remove __div__ and __rdiv__ once we migrated to python3
    if sys.version_info[0] == 2:
        __div__ = __truediv__
        __rdiv__ = __rtruediv__

    def __repr__(self):
        return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

    def _compare(self, other):
        """Returns the sign (-1, 0 or 1) of self - other.

        Returns NotImplemented if `other` cannot be converted to
        ExtendedFloat, so that python can try the reflected operation.
        """
        if not isinstance(other, ExtendedFloat):
            try:
                other = ExtendedFloat(other)
            except TypeError:
                return NotImplemented
        diff = self - other
        return (diff.mantissa > 0) - (diff.mantissa < 0)

    # python3 ignores __cmp__, so rich comparisons are spelled out explicitly.
    def __eq__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign == 0

    def __ne__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign != 0

    def __lt__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign < 0

    def __le__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign <= 0

    def __gt__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign > 0

    def __ge__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign >= 0

    def __hash__(self):
        # Defining __eq__ would otherwise make instances unhashable on
        # python3. Values exactly representable as a built-in float hash like
        # that float, because they also compare equal to it.
        try:
            as_float = float(self)
        except OverflowError:
            as_float = None
        if (as_float is not None and
                math.frexp(as_float) == (self.mantissa, self.exponent)):
            return hash(as_float)
        return hash((self.mantissa, self.exponent))

    def __cmp__(self, other):
        # Kept for python2 callers; avoids the cmp() builtin, which does not
        # exist on python3.
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        diff = self - other
        return (diff.mantissa > 0) - (diff.mantissa < 0)