Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 1 | # Copyright 2017 The Chromium OS Authors. All rights reserved. |
| 2 | # Use of this source code is governed by a BSD-style license that can be |
| 3 | # found in the LICENSE file. |
| 4 | """Math utilities.""" |
| 5 | |
| 6 | from __future__ import print_function |
| 7 | |
| 8 | import copy |
| 9 | import math |
| 10 | import sys |
| 11 | |
| 12 | |
def _log(x):
    """Returns the natural log of x, mapping near-zero input to 0.0.

    Accumulated rounding error can leave x at zero or at a tiny negative
    value. Any x whose magnitude is below 1e-15 therefore yields 0.0 instead
    of being handed to math.log.
    """
    is_negligible = abs(x) < 1e-15
    return 0.0 if is_negligible else math.log(x)
| 19 | |
| 20 | |
def _xlogx(x):
    """Returns x * log(x), with the logarithm computed by _log."""
    log_of_x = _log(x)
    return x * log_of_x
| 23 | |
| 24 | |
def least(values):
    """Returns the smallest strictly positive entry of |values|.

    Args:
      values: Iterable of numbers; entries that are zero or negative are
          ignored.

    Returns:
      The minimum among the entries greater than zero.

    Raises:
      ValueError: if no entry is greater than zero.
    """
    positives = (candidate for candidate in values if candidate > 0)
    return min(positives)
| 38 | |
| 39 | |
class EntropyStats(object):
    r"""Incrementally tracks the cross entropy of a collection of values.

    This is a math helper for NoisyBinarySearch. The values are probabilities
    that need not be normalized; the class keeps just enough state to report
    their cross entropy after cheap incremental updates.

    Algorithm:
      For unnormalized p_i, let S = \sum_i p_i and p_i' = p_i/S. Then
        CrossEntropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                        = \sum_i { -(p_i/S) (log p_i - log S) }
                        = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                        = -1/S * \sum_i { p_i*log p_i } + log S
                        = log S - 1/S * \sum_i { p_i*log p_i }

      Hence only two aggregates are stored: sum=|S| and
      sum_log=|\sum_i { p_i*log p_i }|.
    """

    def __init__(self, p):
        """Computes the two aggregates from the initial values.

        Args:
          p: A list of probability values. Could be unnormalized.
        """
        self.sum = sum(p)
        self.sum_log = sum(_xlogx(p_i) for p_i in p)

    def entropy(self):
        """Returns the cross entropy, or 0 when the values sum to zero."""
        total = self.sum
        if total == 0:
            return 0
        return _log(total) - self.sum_log / total

    def replace(self, old, new):
        """Swaps one value of the collection for another, in place.

        Args:
          old: The value currently in the collection that is being removed.
          new: The value taking its place.
        """
        delta = new - old
        self.sum += delta
        delta_log = _xlogx(new) - _xlogx(old)
        self.sum_log += delta_log

    def multiply(self, value):
        """Scales every value of the collection by |value|.

        This instance is left untouched.

        Args:
          value: The common factor.

        Returns:
          A new EntropyStats instance describing the scaled collection.
        """
        scaled = copy.copy(self)
        # \sum_i { (p_i*value) * log (p_i*value) }
        #   = \sum_i { (p_i*value) * (log p_i + log value) }
        #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
        scaled_sum_log = value * self.sum_log
        shift = self.sum * _xlogx(value)
        scaled.sum_log = scaled_sum_log + shift

        scaled.sum *= value
        return scaled
| 100 | |
| 101 | |
class ExtendedFloat(object):
    """Custom floating point number with larger range of exponent.

    Problem:
      In the formula of noisy binary search algorithm, we need to calculate
      p**n (where n is test count) and normalize probabilities. When n is
      hundreds or more, p**n goes too small and becomes zero (underflow), the
      normalization step will fail.

      Since n > 100 is not unreasonable large for noisy bisection, we need a
      numeric type less likely or never underflow.

    Solution:
      Suppose a floating number is represented as
        x = mantissa * 2.**exponent
      We use python's float to represent `mantissa` and python's int to
      represent `exponent`.

      For mantissa, we can accept calculated probability is not super accurate
      because the worst case is just running test one or two more times. So we
      keep mantissa in python's built-in float.

      For exponent, we represent it using python's int, which is virtually
      unlimited digits.

    Alternative solutions:
      1. Maybe we can rewrite the formula to avoid underflow. But I want to
         keep the arithmetic code simple.

      2. Use python stdlib's decimal or fractions. But they are very slow --
         decimal is 6-50x slower, fractions is 9-760x slower while
         ExtendedFloat is about 1.2-3.3x slower [1].

      [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                      hard N=1000, oracle=(0.45, 0.55)

    Operators:
      Both the Python 2 protocol (__div__, __rdiv__, __cmp__) and the Python 3
      protocol (__truediv__, __rtruediv__, rich comparisons) are provided, so
      `/`, `<`, `==` and friends work under either interpreter.

    Attributes:
      mantissa: The mantissa part of floating number. Its range is 0.5 <=
          abs(mantissa) < 1.0.
      exponent: The exponent part of floating number.
    """

    def __init__(self, value, exp=0):
        """Builds the normalized representation of value * 2**exp.

        Args:
          value: A number accepted by math.frexp.
          exp: Additional power-of-two exponent applied on top of |value|.
        """
        self.mantissa, self.exponent = math.frexp(value)
        # A zero value always gets exponent 0, whatever |exp| is.
        self.exponent = self.exponent + exp if value else 0

    def __float__(self):
        """Converts to a built-in float via math.ldexp."""
        return math.ldexp(self.mantissa, self.exponent)

    def __neg__(self):
        return ExtendedFloat(-self.mantissa, self.exponent)

    def __abs__(self):
        return ExtendedFloat(abs(self.mantissa), self.exponent)

    def __add__(self, other):
        """Returns self + other; |other| may be a plain number."""
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        if self.mantissa == 0:
            return other
        # Always align the operand with the smaller exponent to the larger
        # one, so the shifted mantissa can only shrink.
        if self.exponent < other.exponent:
            return other.__add__(self)
        value = math.ldexp(other.mantissa, other.exponent - self.exponent)
        return ExtendedFloat(value + self.mantissa, self.exponent)

    __radd__ = __add__

    def __pow__(self, p):
        """Returns self ** p.

        Args:
          p: The power. Values at or above -sys.float_info.min_exp are handled
              by repeated squaring, which uses `>>` and `&` on p.
        """
        # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
        # underflow for small p.
        if p < -sys.float_info.min_exp:
            return ExtendedFloat(self.mantissa**p, self.exponent * p)

        half_pow = self.__pow__(p >> 1)
        result = half_pow.__mul__(half_pow)
        if p & 1:
            result = result.__mul__(self)
        return result

    def __sub__(self, other):
        return self.__add__(-other)

    def __rsub__(self, other):
        return ExtendedFloat(other).__add__(-self)

    def __mul__(self, other):
        """Returns self * other; |other| may be a plain number."""
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa * other.mantissa,
                             self.exponent + other.exponent)

    __rmul__ = __mul__

    def __div__(self, other):
        """Returns self / other; |other| may be a plain number."""
        if not isinstance(other, ExtendedFloat):
            other = ExtendedFloat(other)
        return ExtendedFloat(self.mantissa / other.mantissa,
                             self.exponent - other.exponent)

    def __rdiv__(self, other):
        return ExtendedFloat(other).__div__(self)

    # Python 3 dispatches `/` to __truediv__/__rtruediv__ and never looks at
    # __div__/__rdiv__, so expose the same implementations under both names.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __repr__(self):
        return "ExtendedFloat(%f, %d)" % (self.mantissa, self.exponent)

    def _compare(self, other):
        """Returns the sign (-1, 0 or 1) of self - other.

        Returns NotImplemented when |other| cannot be converted to an
        ExtendedFloat, so Python can fall back to its default handling.
        """
        if not isinstance(other, ExtendedFloat):
            try:
                other = ExtendedFloat(other)
            except TypeError:
                return NotImplemented
        diff = self.__sub__(other).mantissa
        return (diff > 0) - (diff < 0)

    def __cmp__(self, other):
        # Kept for Python 2 callers; avoids the builtin cmp(), which was
        # removed in Python 3.
        return self._compare(other)

    def __eq__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign == 0

    def __ne__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign != 0

    def __lt__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign < 0

    def __le__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign <= 0

    def __gt__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign > 0

    def __ge__(self, other):
        sign = self._compare(other)
        return sign if sign is NotImplemented else sign >= 0

    def __hash__(self):
        # Defining __eq__ disables the inherited hash on Python 3, so provide
        # one. Objects that compare equal must hash equally, including a plain
        # float equal to this value, hence the float hash is used whenever the
        # value fits in a float.
        try:
            return hash(float(self))
        except OverflowError:
            return hash((self.mantissa, self.exponent))