blob: 6fc5f0a21c1dcc278e91dcdd8ee6f03ccdac92b6 [file] [log] [blame]
Kuang-che Wu6e4beca2018-06-27 17:45:02 +08001# -*- coding: utf-8 -*-
Kuang-che Wu88875db2017-07-20 10:47:53 +08002# Copyright 2017 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
Kuang-che Wud3ff5862019-11-26 11:41:05 +080010import functools
Kuang-che Wu88875db2017-07-20 10:47:53 +080011import math
12import sys
13
Kuang-che Wua7ddf9b2019-11-25 18:59:57 +080014import six
15
Kuang-che Wu88875db2017-07-20 10:47:53 +080016
def log(x):
    """Natural logarithm that tolerates zero and tiny negative input.

    Args:
      x: The value to take the logarithm of.

    Returns:
      0.0 when x lies in the interval (-1e-9, 0]; math.log(x) otherwise.

    Raises:
      ValueError: (from math.log) if x <= -1e-9.
    """
    # Due to arithmetic error, x may become tiny negative.
    # -1e-10 is not enough, see strategy_test.test_next_idx_arithmetic_error
    # for detail.
    is_effectively_zero = -1e-9 < x <= 0
    return 0.0 if is_effectively_zero else math.log(x)
25
26
def _xlogx(x):
    """Returns x * log(x), using this module's zero-tolerant log()."""
    log_x = log(x)
    return x * log_x
Kuang-che Wu88875db2017-07-20 10:47:53 +080029
30
def least(values):
    """Picks the smallest strictly positive element.

    Args:
      values: Non-negative values.

    Returns:
      The minimum among the elements that are greater than zero.

    Raises:
      ValueError: if no element is greater than zero (all values are zero or
          negative, or there are no values at all).
    """
    positives = [candidate for candidate in values if candidate > 0]
    # min() of an empty sequence raises the documented ValueError.
    return min(positives)
44
45
def average(values):
    """Computes the arithmetic mean of values.

    Args:
      values: numbers, at least one element

    Returns:
      The mean as a float.

    Raises:
      ValueError: if empty
    """
    if values:
        total = sum(values)
        # float() keeps the division exact-typed under Python 2 as well.
        return float(total) / len(values)
    raise ValueError('calculate average of empty list')
58
59
class EntropyStats(object):
    r"""A data structure to maintain cross entropy of a set of values.

    This is a math helper for NoisyBinarySearch. Given a set of probability,
    this class can incrementally calculate their cross entropy.

    Algorithm:
      Given unnormalized p_i, their cross entropy could be expressed as
        Let S = \sum_i p_i
        Normalized(p): p_i' = p_i/S
        CrossEntropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                        = \sum_i { -(p_i/S) (log p_i - log S) }
                        = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                        = -1/S * \sum_i { p_i*log p_i } + log S
                        = log S - 1/S * \sum_i { p_i*log p_i }

      So we can maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }| for
      incremental update.

    Attributes:
      sum: S, the sum of the (unnormalized) values.
      sum_log: \sum_i { p_i*log p_i } over the same values.
    """  # pylint: disable=docstring-trailing-quotes

    def __init__(self, p):
        """Initializes EntropyStats.

        Args:
          p: An iterable of probability values. Could be unnormalized. A
              one-shot iterator (e.g. a generator) is accepted; it is
              consumed exactly once.
        """
        # |p| is traversed twice below. Materialize it first, otherwise a
        # generator would be exhausted by the first pass and sum_log would
        # silently end up as 0.
        values = list(p)
        self.sum = sum(values)
        self.sum_log = sum(_xlogx(x) for x in values)

    def entropy(self):
        """Returns cross entropy.

        Returns:
          log S - sum_log / S, or 0 if the sum S is zero.
        """
        if self.sum == 0:
            return 0
        return log(self.sum) - self.sum_log / self.sum

    def replace(self, old, new):
        """Replaces one random variable in the collection with another.

        Updates the running sums in place; |old| is assumed to be a value
        currently accounted for in the collection (this is not checked).

        Args:
          old: original value to be replaced.
          new: new value
        """
        self.sum += new - old
        self.sum_log += _xlogx(new) - _xlogx(old)

    def multiply(self, value):
        """Multiplies all values in the collection by |value|.

        This instance is left unmodified.

        Args:
          value: the factor applied to every value.

        Returns:
          A new instance of EntropyStats with calculated value.
        """
        other = copy.copy(self)
        # \sum_i { (p_i*value) * log (p_i*value) }
        #   = \sum_i { (p_i*value) * (log p_i + log value) }
        #   = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
        #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
        other.sum_log = value * self.sum_log + self.sum * _xlogx(value)

        other.sum *= value
        return other
120
121
@functools.total_ordering
class ExtendedFloat(object):
    """Custom floating point number with larger range of exponent.

    Problem:
      In the formula of noisy binary search algorithm, we need to calculate
      p**n (where n is test count) and normalize probabilities. When n is
      hundreds or more, p**n goes too small and becomes zero (underflow), the
      normalization step will fail.

      Since n > 100 is not unreasonably large for noisy bisection, we need a
      numeric type less likely or never underflow.

    Solution:
      Suppose a floating number is represented as
        x = mantissa * 2.**exponent
      We use python's float to represent `mantissa` and python's int to
      represent `exponent`.

      For mantissa, we can accept calculated probability is not super accurate
      because the worst case is just running test one or two more times. So we
      keep mantissa in python's built-in float.

      For exponent, we represent it using python's int, which is virtually
      unlimited digits.

    Alternative solutions:
      1. Maybe we can rewrite the formula to avoid underflow. But I want to
         keep the arithmetic code simple.

      2. Use python stdlib's decimal or fractions. But they are very slow --
         decimal is 6-50x slower, fractions is 9-760x slower while
         ExtendedFloat is about 1.2-3.3x slower [1].

      [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                      hard N=1000, oracle=(0.45, 0.55)

    Attributes:
      mantissa: The mantissa part of floating number. Its range is 0.5 <=
          abs(mantissa) < 1.0 (or 0 for the value zero).
      exponent: The exponent part of floating number. Always 0 for the value
          zero.
    """

    def __init__(self, value, exp=0):
        """Builds the number value * 2**exp.

        Args:
          value: a real number accepted by math.frexp.
          exp: additional power-of-two scale applied to |value|.

        Raises:
          TypeError: (from math.frexp) if |value| is not a real number.
        """
        self.mantissa, self.exponent = math.frexp(value)
        # Zero is kept canonical (exponent 0) so comparisons of zeros agree.
        self.exponent = self.exponent + exp if value else 0

    def __float__(self):
        return math.ldexp(self.mantissa, self.exponent)

    def __neg__(self):
        return ExtendedFloat(-self.mantissa, self.exponent)

    def __abs__(self):
        return ExtendedFloat(abs(self.mantissa), self.exponent)

    def __add__(self, other):
        """Returns self + other; |other| may be a plain number."""
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        if self.mantissa == 0:
            return other
        # Make sure the operand with the larger exponent is on the left, so
        # the other one is scaled *down* (it may harmlessly underflow to 0).
        if self.exponent < other.exponent:
            return other.__add__(self)
        value = math.ldexp(other.mantissa, other.exponent - self.exponent)
        return ExtendedFloat(value + self.mantissa, self.exponent)

    __radd__ = __add__

    def __pow__(self, p):
        """Returns self ** p for an integer |p|.

        Args:
          p: integer exponent; may be negative.

        Raises:
          ZeroDivisionError: if self is zero and p is negative.
        """
        # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
        # underflow for small p.
        if p < -sys.float_info.min_exp:
            try:
                return ExtendedFloat(self.mantissa**p, self.exponent * p)
            except OverflowError:
                # mantissa**p can only overflow for a large negative p.
                if p >= 0:
                    raise
                # The result itself is representable; compute the positive
                # power (which cannot overflow) and take its reciprocal.
                return ExtendedFloat(1.0).__truediv__(self.__pow__(-p))

        # Large p: exponentiation by squaring, so every intermediate mantissa
        # power stays within float range.
        half_pow = self.__pow__(p >> 1)
        result = half_pow.__mul__(half_pow)
        if p & 1:
            result = result.__mul__(self)
        return result

    def __sub__(self, other):
        return self.__add__(-other)

    def __rsub__(self, other):
        return ExtendedFloat(other).__add__(-self)

    def __mul__(self, other):
        """Returns self * other; |other| may be a plain number."""
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa * other.mantissa,
                             self.exponent + other.exponent)

    __rmul__ = __mul__

    def __truediv__(self, other):
        """Returns self / other; |other| may be a plain number.

        Raises:
          ZeroDivisionError: if |other| is zero.
        """
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa / other.mantissa,
                             self.exponent - other.exponent)

    def __rtruediv__(self, other):
        return ExtendedFloat(other).__truediv__(self)

    # Python 2 dispatches the `/` operator to __div__/__rdiv__ (unless the
    # caller enabled true division).
    # TODO(kcwu): remove __div__ and __rdiv__ once we migrated to python3
    if sys.version_info[0] == 2:
        __div__ = __truediv__
        __rdiv__ = __rtruediv__

    def __repr__(self):
        return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

    def __eq__(self, other):
        """Value equality; |other| may be a plain number.

        Objects that cannot be interpreted as a real number compare unequal
        instead of raising.
        """
        if not isinstance(other, ExtendedFloat):
            try:
                other = ExtendedFloat(other)
            except TypeError:
                # Not a number (None, str, ...): let Python fall back to the
                # reflected comparison / identity.
                return NotImplemented
        return (self.mantissa == other.mantissa and
                self.exponent == other.exponent)

    def __ne__(self, other):
        # Python 2 does not derive `!=` from __eq__, and
        # functools.total_ordering does not supply it either.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        """Returns whether self < other; |other| may be a plain number.

        Raises:
          TypeError: if |other| is not a real number.
        """
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        # Different signs, or at least one operand is zero: the mantissas
        # alone decide the order.
        if self.mantissa * other.mantissa <= 0:
            return self.mantissa < other.mantissa

        if self.exponent == other.exponent:
            return self.mantissa < other.mantissa

        # Same sign, different exponent: a larger exponent means a larger
        # magnitude, which is "greater" for positives and "less" for
        # negatives.
        if self.mantissa > 0:
            return self.exponent < other.exponent
        return self.exponent > other.exponent

    def __round__(self, ndigits=None):
        """Supports round(x) and round(x, ndigits) via conversion to float."""
        if ndigits is None:
            return round(float(self))
        return round(float(self), ndigits)