blob: f3d32442a3ef0e0e81b7b7d8869cae8644acd536 [file] [log] [blame]
Kuang-che Wu6e4beca2018-06-27 17:45:02 +08001# -*- coding: utf-8 -*-
Kuang-che Wu88875db2017-07-20 10:47:53 +08002# Copyright 2017 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
Kuang-che Wud3ff5862019-11-26 11:41:05 +080010import functools
Kuang-che Wu88875db2017-07-20 10:47:53 +080011import math
12import sys
13
14
def log(x):
  """Natural logarithm that tolerates tiny non-positive rounding errors.

  Args:
    x: The number to take the logarithm of.

  Returns:
    math.log(x) for positive x, or 0.0 when x lies in (-1e-9, 0].

  Raises:
    ValueError: from math.log when x <= -1e-9.
  """
  # Arithmetic error can push a value that should be exactly zero slightly
  # below zero. A tolerance of -1e-10 is not enough; see
  # strategy_test.test_next_idx_arithmetic_error for detail.
  tolerance = 1e-9
  if -tolerance < x <= 0:
    return 0.0
  return math.log(x)
23
24
def _xlogx(x):
  """Returns the entropy summand x * log(x), using the tolerant log()."""
  log_x = log(x)
  return x * log_x
Kuang-che Wu88875db2017-07-20 10:47:53 +080027
28
def least(values):
  """Returns the smallest strictly positive element of values.

  Args:
    values: An iterable of non-negative numbers.

  Returns:
    The minimum among the elements that are greater than zero.

  Raises:
    ValueError: if no element is positive (including empty input).
  """
  positives = [v for v in values if v > 0]
  return min(positives)
42
43
def average(values):
  """Calculates the arithmetic mean of values.

  Args:
    values: An iterable of numbers with at least one element. Any iterable,
      including a generator, is accepted; it is consumed exactly once.

  Returns:
    The mean as a float.

  Raises:
    ValueError: if |values| is empty.
  """
  # Materialize so that iterators without len() (e.g. generators) work and
  # so that the emptiness check reflects the actual number of elements.
  values = list(values)
  if not values:
    raise ValueError('calculate average of empty list')
  return float(sum(values)) / len(values)
56
57
class Averager:
  """Accumulates values one at a time and reports their mean."""

  def __init__(self):
    # Running number of values added and their running sum.
    self.count = 0
    self.total = 0

  def add(self, value):
    """Records one more value."""
    self.count += 1
    self.total += value

  def average(self):
    """Returns the mean of all values added so far.

    Raises:
      ValueError: if no value has been added yet.
    """
    if not self.count:
      raise ValueError('calculate average of empty list')
    return self.total / self.count
73
74
class EntropyStats:
  r"""Incrementally maintained entropy of a set of (unnormalized) weights.

  This is a math helper for NoisyBinarySearch. The weights p_i need not sum to
  one; the entropy reported is that of the normalized distribution p_i / S.

  Derivation:
    With S = \sum_i p_i and q_i = p_i / S,
      H = \sum_i { -q_i * log q_i }
        = \sum_i { -(p_i/S) * (log p_i - log S) }
        = log S - 1/S * \sum_i { p_i * log p_i }
    So it suffices to track sum = S and sum_log = \sum_i { p_i * log p_i }.
    Both can be updated with a constant number of operations when one weight
    is swapped for another or when every weight is scaled by the same factor.
  """  # pylint: disable=docstring-trailing-quotes

  def __init__(self, p):
    """Builds the statistics from an initial collection of weights.

    Args:
      p: A list of probability values. They may be unnormalized.
    """
    self.sum = sum(p)
    self.sum_log = sum(_xlogx(weight) for weight in p)

  def entropy(self):
    """Returns the entropy of the normalized weights (natural log).

    Returns 0 when the weights sum to zero.
    """
    total = self.sum
    if total == 0:
      return 0
    return log(total) - self.sum_log / total

  def replace(self, old, new):
    """Swaps one weight in the collection for another, in place.

    Args:
      old: The weight being removed.
      new: The weight taking its place.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Scales every weight in the collection by |value|.

    Uses \sum_i (v*p_i) log(v*p_i) = v * sum_log + S * (v log v).

    Args:
      value: The common scale factor.

    Returns:
      A new EntropyStats for the scaled weights; self is left unchanged.
    """
    scaled = copy.copy(self)
    weighted = value * self.sum_log
    shift = self.sum * _xlogx(value)
    scaled.sum_log = weighted + shift
    scaled.sum *= value
    return scaled
135
136
@functools.total_ordering
class ExtendedFloat:
  """Custom floating point number with a larger exponent range.

  Problem:
    In the formula of the noisy binary search algorithm, we need to calculate
    p**n (where n is the test count) and normalize probabilities. When n is in
    the hundreds or more, p**n becomes too small and rounds to zero
    (underflow), so the normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is less likely to underflow, or never does.

  Solution:
    A number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to represent
    `exponent`.

    For mantissa, we can accept that the calculated probability is not super
    accurate, because the worst case is just running a test one or two more
    times. So we keep mantissa in python's built-in float.

    For exponent, we use python's int, which has virtually unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Attributes:
    mantissa: The mantissa part of the number. Its range is
      0.5 <= abs(mantissa) < 1.0, or it is 0.0 for zero.
    exponent: The power-of-two exponent (a python int; 0 for zero).
  """

  def __init__(self, value, exp=0):
    """Initializes to value * 2.**exp with a normalized mantissa.

    Args:
      value: A real number (anything math.frexp accepts).
      exp: Extra power-of-two exponent applied to |value|.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Pin zero's exponent to 0 so that all zeros share one representation.
    self.exponent = self.exponent + exp if value else 0

  @staticmethod
  def _coerce(other):
    """Returns |other| as an ExtendedFloat, or None if it is not a real number."""
    if isinstance(other, ExtendedFloat):
      return other
    try:
      return ExtendedFloat(other)
    except TypeError:
      return None

  def __float__(self):
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Always scale the operand with the smaller exponent down to the larger
    # one, so the shifted mantissa can only lose precision, never overflow.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Raises self to the integer power |p|.

    Large exponents use exponentiation by squaring, so no intermediate float
    mantissa underflows or overflows.
    """
    # Because 0.5 <= abs(mantissa) < 1.0, abs(mantissa)**p stays within the
    # normal float range as long as abs(p) < -sys.float_info.min_exp.
    limit = -sys.float_info.min_exp
    if -limit < p < limit:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)
    if p < 0:
      # mantissa**p would overflow for large negative p. Take the reciprocal
      # first (it is re-normalized) and raise it to the positive power.
      return (1 / self).__pow__(-p)

    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __truediv__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rtruediv__(self, other):
    return ExtendedFloat(other).__truediv__(self)

  def __repr__(self):
    return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

  def __eq__(self, other):
    other = self._coerce(other)
    if other is None:
      # Let Python fall back to its default handling (identity, i.e. False)
      # instead of raising for non-numeric operands such as None.
      return NotImplemented
    # Representations are normalized, so equal values have equal fields.
    return self.mantissa == other.mantissa and self.exponent == other.exponent

  def __lt__(self, other):
    other = self._coerce(other)
    if other is None:
      return NotImplemented
    # Different signs, or at least one zero: the mantissas alone decide.
    if self.mantissa * other.mantissa <= 0:
      return self.mantissa < other.mantissa

    if self.exponent == other.exponent:
      return self.mantissa < other.mantissa

    # Same sign and different exponents: a larger exponent means a larger
    # magnitude, which is larger for positives and smaller for negatives.
    if self.mantissa > 0:
      return self.exponent < other.exponent
    return self.exponent > other.exponent

  def __round__(self, digits=None):
    # Default digits=None so that round(x) works and returns an int, like the
    # built-in numeric types.
    return round(float(self), digits)