blob: 006f6ecadea3b6fa9b12e723067fedebf7aa643e [file] [log] [blame]
Kuang-che Wu6e4beca2018-06-27 17:45:02 +08001# -*- coding: utf-8 -*-
Kuang-che Wu88875db2017-07-20 10:47:53 +08002# Copyright 2017 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
10import math
11import sys
12
13
def log(x):
  """Returns the natural logarithm of x, tolerating zero and tiny negatives.

  Args:
    x: A number. Values in the range -1e-9 < x <= 0 are treated as zero-ish
        arithmetic noise.

  Returns:
    0.0 if -1e-9 < x <= 0, otherwise math.log(x).

  Raises:
    ValueError: if x <= -1e-9 (propagated from math.log).
  """
  # Accumulated arithmetic error may push a value slightly below zero. A
  # tolerance of -1e-10 is not enough, see
  # strategy_test.test_next_idx_arithmetic_error for detail.
  is_noise_around_zero = -1e-9 < x <= 0
  return 0.0 if is_noise_around_zero else math.log(x)
22
23
def _xlogx(x):
  """Returns x * log(x), using this module's zero-tolerant log wrapper."""
  log_of_x = log(x)
  return x * log_of_x
Kuang-che Wu88875db2017-07-20 10:47:53 +080026
27
def least(values):
  """Returns the smallest strictly positive element of values.

  Args:
    values: Iterable of numbers. Elements that are zero or negative are
        ignored.

  Returns:
    The minimum among the elements greater than zero.

  Raises:
    ValueError: if no element is greater than zero.
  """
  positives = [candidate for candidate in values if candidate > 0]
  return min(positives)
41
42
def average(values):
  """Returns the arithmetic mean of values as a float.

  Args:
    values: A non-empty sized collection of numbers.

  Returns:
    float(sum(values)) divided by len(values).

  Raises:
    ValueError: if values is empty
  """
  if values:
    total = float(sum(values))
    return total / len(values)
  raise ValueError('calculate average of empty list')
55
56
class EntropyStats(object):
  r"""Incrementally tracks the entropy of a collection of weights.

  This is a math helper for NoisyBinarySearch. The weights are probabilities
  that need not be normalized; the value reported by entropy() is computed
  for the normalized distribution.

  Algorithm:
    For unnormalized p_i, let S = \sum_i p_i and p_i' = p_i/S. Then
      Entropy(p) = \sum_i { -p_i' * log p_i' }
                 = \sum_i { -(p_i/S) * (log p_i - log S) }
                 = -1/S * \sum_i { p_i * log p_i } + log S * \sum_i { p_i/S }
                 = log S - 1/S * \sum_i { p_i * log p_i }

    Hence keeping S and \sum_i { p_i * log p_i } up to date is enough to
    answer entropy queries after each incremental change.

  Attributes:
    sum: S, the sum of all weights.
    sum_log: \sum_i { p_i * log p_i } over all weights.
  """

  def __init__(self, p):
    """Builds the running totals from the initial weights.

    Args:
      p: A list of probability values. Could be unnormalized. It is iterated
          twice, so it has to be re-iterable.
    """
    self.sum = sum(p)
    self.sum_log = sum(_xlogx(weight) for weight in p)

  def entropy(self):
    """Returns the entropy of the normalized weights, or 0 if they sum to 0."""
    total = self.sum
    if total == 0:
      return 0
    weighted_logs = self.sum_log
    return log(total) - weighted_logs / total

  def replace(self, old, new):
    """Swaps one weight of the collection for another, in place.

    Args:
      old: the weight currently in the collection that is being removed.
      new: the weight taking its place.
    """
    sum_delta = new - old
    self.sum += sum_delta
    sum_log_delta = _xlogx(new) - _xlogx(old)
    self.sum_log += sum_log_delta

  def multiply(self, value):
    """Scales every weight in the collection by |value|.

    This instance is left untouched.

    Args:
      value: the common factor applied to all weights.

    Returns:
      A new instance of EntropyStats holding the scaled totals.
    """
    scaled = copy.copy(self)
    # With v = value:
    #   \sum_i { (p_i*v) * log (p_i*v) }
    #     = \sum_i { (p_i*v) * (log p_i + log v) }
    #     = v * \sum_i { p_i log p_i } + \sum_i { p_i } * v * log v
    scaled.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    scaled.sum *= value
    return scaled
117
118
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    In the formula of noisy binary search algorithm, we need to calculate p**n
    (where n is test count) and normalize probabilities. When n is hundreds or
    more, p**n goes too small and becomes zero (underflow), the normalization
    step will fail.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is less likely to underflow, or never does.

  Solution:
    Suppose a floating point number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to represent
    `exponent`.

    For mantissa, we can accept calculated probability is not super accurate
    because the worst case is just running test one or two more times. So we
    keep mantissa in python's built-in float.

    For exponent, we represent it using python's int, which is virtually
    unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Operators:
    +, -, *, /, **, unary -, abs(), float() and all six comparisons are
    supported. Plain numbers on either side of a binary operator are converted
    to ExtendedFloat first. Both the Python 2 protocol (__div__, __cmp__) and
    the Python 3 protocol (__truediv__, rich comparisons) are provided.

  Attributes:
    mantissa: The mantissa part of floating number. Its range is 0.5 <=
        abs(mantissa) < 1.0, or exactly 0 for the value zero.
    exponent: The exponent part of floating number. It is 0 for the value
        zero.
  """

  def __init__(self, value, exp=0):
    """Initializes to value * 2.**exp.

    Args:
      value: A number accepted by math.frexp.
      exp: Extra power-of-two exponent applied on top of value.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Zero always gets exponent 0 regardless of |exp|.
    self.exponent = self.exponent + exp if value else 0

  def __float__(self):
    """Converts to a built-in float via math.ldexp (tiny values become 0.0)."""
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Always align the operand with the smaller exponent to the larger one,
    # so the shift below is never positive.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Returns self raised to the power p.

    Args:
      p: The power. Values of -sys.float_info.min_exp and above are handled by
          repeated squaring using bit operations, so they have to be integers.
    """
    # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
    # underflow for small p.
    if p < -sys.float_info.min_exp:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __div__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rdiv__(self, other):
    return ExtendedFloat(other).__div__(self)

  # Python 3 (and python 2 with `from __future__ import division`) dispatches
  # the `/` operator to __truediv__ instead of __div__.
  __truediv__ = __div__
  __rtruediv__ = __rdiv__

  def __repr__(self):
    return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

  def _compare(self, other):
    """Returns -1, 0 or 1 if self is less than, equal to or greater than other.

    Raises:
      TypeError: if other is neither an ExtendedFloat nor convertible to one.
    """
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    diff_mantissa = self.__sub__(other).mantissa
    # Equivalent of python 2's cmp(diff_mantissa, 0); cmp() is gone in python 3.
    return (diff_mantissa > 0) - (diff_mantissa < 0)

  def __cmp__(self, other):
    """Three-way comparison; only consulted by python 2."""
    return self._compare(other)

  # Python 3 ignores __cmp__, so the rich comparison methods are spelled out.
  def __eq__(self, other):
    try:
      return self._compare(other) == 0
    except TypeError:
      # |other| is not a number; let python apply its default handling.
      return NotImplemented

  def __ne__(self, other):
    result = self.__eq__(other)
    return result if result is NotImplemented else not result

  def __lt__(self, other):
    return self._compare(other) < 0

  def __le__(self, other):
    return self._compare(other) <= 0

  def __gt__(self, other):
    return self._compare(other) > 0

  def __ge__(self, other):
    return self._compare(other) >= 0

  def __hash__(self):
    # Defining __eq__ disables the inherited hash in python 3, so provide one.
    # Equal values must hash equally, including a built-in float that compares
    # equal to this instance, hence the float hash is preferred.
    try:
      return hash(float(self))
    except OverflowError:
      return hash((self.mantissa, self.exponent))