# -*- coding: utf-8 -*-
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Math utilities."""
6
7from __future__ import print_function
8
9import copy
10import math
11import sys
12
13
def _log(x):
  """Natural logarithm that tolerates arithmetic noise around zero.

  Accumulated floating point error can leave a quantity that should be zero
  as a tiny positive or negative number. Any x with abs(x) < 1e-15 is
  therefore treated as zero and maps to 0.0 instead of raising. Values further
  below zero are not special-cased and still raise ValueError from math.log.

  Args:
    x: The value to take the logarithm of.

  Returns:
    0.0 for near-zero x, otherwise math.log(x).
  """
  is_negligible = abs(x) < 1e-15
  if not is_negligible:
    return math.log(x)
  return 0.0
20
21
def _xlogx(x):
  """Returns the entropy summand x * log(x).

  The logarithm comes from _log, so x within 1e-15 of zero contributes 0
  instead of raising, which agrees with the limit of x*log(x) as x -> 0.
  """
  log_of_x = _log(x)
  return x * log_of_x
24
25
def least(values):
  """Returns the smallest strictly positive element of values.

  Zeros (and any negative entries) are skipped, so the result is the minimum
  over the positive subset only.

  Args:
    values: An iterable of numbers, expected to be non-negative.

  Returns:
    The minimum of the values that are greater than zero.

  Raises:
    ValueError: if no value is greater than zero (including empty input).
  """
  positive = [value for value in values if value > 0]
  return min(positive)
39
40
class EntropyStats(object):
  r"""Incrementally maintained entropy of a collection of weights.

  This is a math helper for NoisyBinarySearch. Given a set of probabilities,
  which need not be normalized, this class tracks the Shannon entropy of the
  normalized distribution. It updates cheaply when one value is replaced or
  when every value is scaled.

  Algorithm:
    Given unnormalized p_i, let
      S = \sum_i p_i
      Normalized(p): p_i' = p_i/S
    Then
      Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                 = \sum_i { -(p_i/S) (log p_i - log S) }
                 = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S } * log S
                 = -1/S * \sum_i { p_i*log p_i } + log S
                                                  (since \sum_i p_i/S = 1)
                 = log S - 1/S * \sum_i { p_i*log p_i }

    So we can maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }| for
    incremental update.

  Attributes:
    sum: Sum of all values, S.
    sum_log: Sum of p_i * log(p_i) over all values.
  """

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable of probability values. They may be unnormalized. The
        iterable is consumed exactly once, so generators are accepted.
    """
    # Materialize first: iterating a one-shot iterator twice (once for the
    # sum, once for sum_log) would compute sum_log from an exhausted iterator.
    values = list(p)
    self.sum = sum(values)
    self.sum_log = sum(map(_xlogx, values))

  def entropy(self):
    """Returns the entropy of the normalized distribution.

    Returns 0 when the collection sums to zero, i.e. it has no probability
    mass to normalize.
    """
    if self.sum == 0:
      return 0
    return _log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one value in the collection with another.

    This only adjusts the running sums; it does not check that |old| is
    actually a member of the collection.

    Args:
      old: Original value to be replaced.
      new: New value.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Multiplies all values in the collection by |value|.

    This instance is left unchanged.

    Args:
      value: Scale factor applied to every value.

    Returns:
      A new instance of EntropyStats with calculated value.
    """
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    # = \sum_i { (p_i*value) * (log p_i + log value) }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    other.sum *= value
    return other
101
102
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    In the formula of noisy binary search algorithm, we need to calculate p**n
    (where n is test count) and normalize probabilities. When n is hundreds or
    more, p**n becomes too small and rounds to zero (underflow), so the
    normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is much less likely to underflow.

  Solution:
    A number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to
    represent `exponent`.

    For mantissa, we can accept that the calculated probability is not super
    accurate because the worst case is just running a test one or two more
    times. So we keep mantissa in python's built-in float.

    For exponent, we represent it using python's int, which has virtually
    unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Supported operations: unary -, abs(), +, -, *, / (through both the
  Python 2 __div__ and the Python 3 __truediv__ protocols), ** with an int
  exponent, float(), comparisons with ExtendedFloat/int/float, and hashing.

  Attributes:
    mantissa: The mantissa part of floating number. Its range is
      0.5 <= abs(mantissa) < 1.0, except that it is 0 for a zero value.
    exponent: The exponent part of floating number; 0 for a zero value.
  """

  def __init__(self, value, exp=0):
    """Creates an ExtendedFloat equal to value * 2**exp.

    Args:
      value: A Python int or float.
      exp: Additional power-of-two exponent applied to value.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Keep zero canonical (mantissa 0, exponent 0) regardless of exp.
    self.exponent = self.exponent + exp if value else 0

  def __float__(self):
    # Underflows to 0.0 for tiny magnitudes; math.ldexp raises OverflowError
    # for huge ones.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Align to the operand with the larger exponent so the smaller operand is
    # shifted towards zero rather than overflowing.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Raises self to the power p without float underflow or overflow.

    Args:
      p: An int exponent, positive or negative. Non-integer p is only
        supported when abs(p) < -sys.float_info.min_exp.
    """
    limit = -sys.float_info.min_exp
    # Because 0.5 <= abs(mantissa) < 1.0, mantissa**p stays within the normal
    # float range whenever abs(p) < limit.
    if -limit < p < limit:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    if p < 0:
      # mantissa**p would overflow for large negative p; take the reciprocal
      # of the (range-safe) positive power instead.
      return ExtendedFloat(1.0).__div__(self.__pow__(-p))

    # Exponentiation by squaring keeps every intermediate mantissa in range.
    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __div__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rdiv__(self, other):
    return ExtendedFloat(other).__div__(self)

  # Python 3 (and Python 2 modules using "from __future__ import division")
  # dispatch "/" to these names; without them "/" raises TypeError.
  __truediv__ = __div__
  __rtruediv__ = __rdiv__

  def __repr__(self):
    return "ExtendedFloat(%f, %d)" % (self.mantissa, self.exponent)

  def _sign_of_difference(self, other):
    """Returns -1, 0 or 1 as self is less than, equal to or above other."""
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    diff = (self - other).mantissa
    return (diff > 0) - (diff < 0)

  def __cmp__(self, other):
    # Python 2 protocol; Python 3 uses the rich comparisons below.
    return self._sign_of_difference(other)

  def __eq__(self, other):
    try:
      return self._sign_of_difference(other) == 0
    except TypeError:
      # other is not numeric (e.g. None): let Python apply its default.
      return NotImplemented

  def __ne__(self, other):
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __lt__(self, other):
    return self._sign_of_difference(other) < 0

  def __le__(self, other):
    return self._sign_of_difference(other) <= 0

  def __gt__(self, other):
    return self._sign_of_difference(other) > 0

  def __ge__(self, other):
    return self._sign_of_difference(other) >= 0

  def __hash__(self):
    # Values are kept canonical by frexp, so equal ExtendedFloats share
    # (mantissa, exponent). When the value is exactly representable as a
    # float, hash that float so ExtendedFloat(x) and x hash alike.
    try:
      as_float = float(self)
    except OverflowError:
      return hash((self.mantissa, self.exponent))
    if math.frexp(as_float) == (self.mantissa, self.exponent):
      return hash(as_float)
    return hash((self.mantissa, self.exponent))