blob: 46c76e5a0bfd265fd7c801246553edc36ae1e162 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
10import math
11import sys
12
13
def log(x):
  """Returns the natural logarithm of x, mapping (-1e-9, 0] to 0.0.

  Accumulated floating point error can push a value that should be exactly
  zero slightly below zero. Such values, and zero itself, yield 0.0 instead of
  raising. Anything at or below -1e-9 is handed to math.log unchanged, which
  raises ValueError for it.

  Args:
    x: The value to take the logarithm of.

  Returns:
    math.log(x), or 0.0 when -1e-9 < x <= 0.
  """
  # A tolerance of 1e-10 proved too tight; see
  # strategy_test.test_next_idx_arithmetic_error for detail.
  near_zero = -1e-9 < x <= 0
  if not near_zero:
    return math.log(x)
  return 0.0
22
23
def _xlogx(x):
  """Returns x * log(x) using the module's tolerant log().

  Because log() maps 0 to 0.0, an x of 0 yields 0.0 instead of raising.
  """
  log_x = log(x)
  return x * log_x
Kuang-che Wu88875db2017-07-20 10:47:53 +080026
27
def least(values):
  """Returns the smallest strictly positive element of |values|.

  Args:
    values: Iterable of non-negative numbers.

  Returns:
    The minimum among the elements that are greater than zero.

  Raises:
    ValueError: if no element is greater than zero (including empty input).
  """
  positives = [value for value in values if value > 0]
  return min(positives)
41
42
class EntropyStats(object):
  r"""Incrementally maintains the entropy of a collection of weights.

  This is a math helper for NoisyBinarySearch. Given a set of probabilities,
  possibly unnormalized, this class computes the Shannon entropy (natural log,
  i.e. in nats) of their normalized distribution. It supports cheap
  incremental updates.

  Algorithm:
    Given unnormalized p_i, let S = \sum_i p_i, so the normalized values are
    p_i' = p_i/S. Then
      Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                 = \sum_i { -(p_i/S) (log p_i - log S) }
                 = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                 = -1/S * \sum_i { p_i*log p_i } + log S
                 = log S - 1/S * \sum_i { p_i*log p_i }

    So it suffices to maintain sum = S and sum_log = \sum_i { p_i*log p_i }.
    Both can be updated incrementally.

  Attributes:
    sum: S, the sum of all values.
    sum_log: \sum_i p_i * log(p_i), with 0 * log(0) taken as 0.
  """

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable of probability values. Could be unnormalized.
    """
    # Materialize first: the values are traversed twice below. A one-shot
    # iterator would otherwise be exhausted by sum() and leave sum_log at 0.
    values = list(p)
    self.sum = sum(values)
    self.sum_log = sum(map(_xlogx, values))

  def entropy(self):
    """Returns the entropy (natural log) of the normalized values.

    Returns 0.0 when the values sum to zero (e.g. empty or all-zero input).
    """
    if self.sum == 0:
      return 0.0
    return log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one value in the collection with another.

    Only the aggregates are updated. The caller is responsible for |old|
    actually having been part of the collection.

    Args:
      old: Original value to be replaced.
      new: New value.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Multiplies all values in the collection by |value|.

    This instance is left unchanged.

    Args:
      value: Non-negative scale factor.

    Returns:
      A new instance of EntropyStats describing the scaled collection.
    """
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    #   = \sum_i { (p_i*value) * (log p_i + log value) }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)
    other.sum *= value
    return other
103
104
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    The noisy binary search formula needs to calculate p**n (where n is the
    test count) and normalize probabilities. When n is hundreds or more, p**n
    becomes too small and turns into zero (underflow), so the normalization
    step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is less likely to underflow, or never does.

  Solution:
    A floating point number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to
    represent `exponent`.

    For mantissa, we can accept that the calculated probability is not super
    accurate, because the worst case is just running a test one or two more
    times. So we keep mantissa in python's built-in float.

    For exponent, we use python's int, which has virtually unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Both the Python 2 (__div__, __cmp__) and Python 3 (__truediv__, rich
  comparison) operator protocols are supported.

  Attributes:
    mantissa: The mantissa part of floating number. Its range is
        0.5 <= abs(mantissa) < 1.0, or exactly 0.0 for zero.
    exponent: The exponent part of floating number. Always 0 for zero.
  """

  def __init__(self, value, exp=0):
    """Creates the number value * 2.**exp.

    Args:
      value: A number convertible to float, or another ExtendedFloat. An
          ExtendedFloat is copied exactly instead of round-tripping through
          float, which would silently underflow for tiny exponents.
      exp: Extra power of two to scale |value| by.
    """
    if isinstance(value, ExtendedFloat):
      mantissa, exponent = value.mantissa, value.exponent
    else:
      mantissa, exponent = math.frexp(value)
    self.mantissa = mantissa
    # Zero is canonicalized to exponent 0 regardless of |exp|.
    self.exponent = exponent + exp if mantissa else 0

  def __float__(self):
    # Raises OverflowError if the value exceeds float range; tiny values
    # underflow to 0.0.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    if self.exponent < other.exponent:
      return other.__add__(self)
    # self has the larger (or equal) exponent, so align other's mantissa to
    # it. A large gap underflows the aligned mantissa to 0.0, which is fine
    # because other is then negligible relative to self.
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Returns self**p for an integer p (positive, zero or negative)."""
    # Because 0.5 <= abs(mantissa) < 1.0, mantissa**p stays within float's
    # normal range as long as abs(p) < -min_exp: no underflow for positive p
    # and no overflow for negative p. Testing abs(p), not p, is what keeps
    # large negative powers from overflowing here.
    if abs(p) < -sys.float_info.min_exp:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    # Larger powers use repeated squaring. p >> 1 and p & 1 have floor
    # semantics, so (p >> 1) * 2 + (p & 1) == p holds for negative p as well.
    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __div__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rdiv__(self, other):
    return ExtendedFloat(other).__div__(self)

  # `/` dispatches to __truediv__ under `from __future__ import division` and
  # on Python 3.
  __truediv__ = __div__
  __rtruediv__ = __rdiv__

  def __repr__(self):
    return "ExtendedFloat(%f, %d)" % (self.mantissa, self.exponent)

  def _compare(self, other):
    """Returns -1, 0 or 1 as self is less than, equal to or above other."""
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    diff = (self - other).mantissa
    return (diff > 0) - (diff < 0)

  def _compare_in(self, other, accepted):
    """Returns whether _compare(other) is in |accepted|, or NotImplemented."""
    try:
      return self._compare(other) in accepted
    except TypeError:
      # |other| is not a number; let Python try the reflected operation.
      return NotImplemented

  def __cmp__(self, other):
    # Python 2 comparison protocol.
    return self._compare(other)

  # Python 3 ignores __cmp__ and has no cmp() builtin, so the rich comparison
  # methods are provided as well. Python 2 prefers them too, with the same
  # results as __cmp__.
  def __eq__(self, other):
    return self._compare_in(other, (0,))

  def __ne__(self, other):
    return self._compare_in(other, (-1, 1))

  def __lt__(self, other):
    return self._compare_in(other, (-1,))

  def __le__(self, other):
    return self._compare_in(other, (-1, 0))

  def __gt__(self, other):
    return self._compare_in(other, (1,))

  def __ge__(self, other):
    return self._compare_in(other, (0, 1))

  def __hash__(self):
    # Equal values must hash equally, including against plain floats/ints, so
    # hash through float when representable. Values beyond float range can
    # only equal ExtendedFloats with the same normalized (mantissa, exponent).
    try:
      return hash(float(self))
    except OverflowError:
      return hash((self.mantissa, self.exponent))