Kuang-che Wu | 6e4beca | 2018-06-27 17:45:02 +0800 | [diff] [blame] | 1 | # -*- coding: utf-8 -*- |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 2 | # Copyright 2017 The Chromium OS Authors. All rights reserved. |
| 3 | # Use of this source code is governed by a BSD-style license that can be |
| 4 | # found in the LICENSE file. |
| 5 | """Math utilities.""" |
| 6 | |
| 7 | from __future__ import print_function |
| 8 | |
| 9 | import copy |
Kuang-che Wu | d3ff586 | 2019-11-26 11:41:05 +0800 | [diff] [blame] | 10 | import functools |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 11 | import math |
| 12 | import sys |
| 13 | |
Kuang-che Wu | a7ddf9b | 2019-11-25 18:59:57 +0800 | [diff] [blame] | 14 | import six |
| 15 | |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 16 | |
def log(x):
  """Natural logarithm that tolerates zero and tiny negative rounding noise.

  Args:
    x: A number that should mathematically be non-negative.

  Returns:
    0.0 when x lies in (-1e-9, 0] (so that x * log(x) evaluates to 0 at
    x == 0); otherwise math.log(x).

  Raises:
    ValueError: if x is -1e-9 or below (propagated from math.log).
  """
  # Floating point arithmetic can push a value that should be zero slightly
  # below zero. A tolerance of -1e-10 proved too tight (see
  # strategy_test.test_next_idx_arithmetic_error), hence -1e-9.
  near_zero = -1e-9 < x <= 0
  if not near_zero:
    return math.log(x)
  return 0.0
| 25 | |
| 26 | |
def _xlogx(x):
  """Returns x * log(x), the per-term building block of entropy sums.

  Relies on the tolerant log() wrapper, so an input of 0 yields 0.0 instead
  of raising.
  """
  log_x = log(x)
  return x * log_x
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 29 | |
| 30 | |
def least(values):
  """Returns the smallest strictly positive element of values.

  Args:
    values: Non-negative values (any iterable; consumed once).

  Returns:
    The minimum value among those greater than zero.

  Raises:
    ValueError: if no value is greater than zero (including empty input).
  """
  positives = [v for v in values if v > 0]
  # min() raises ValueError when |positives| is empty.
  return min(positives)
| 44 | |
| 45 | |
def average(values):
  """Calculates the arithmetic mean of values.

  Args:
    values: an iterable of numbers with at least one element. Any iterable
      is accepted (list, tuple, generator, ...); it is consumed once.

  Returns:
    The mean as a float.

  Raises:
    ValueError: if values is empty
  """
  # Materialize first: an iterator/generator is always truthy and has no
  # len(), so the emptiness check and the division below would otherwise
  # misbehave (TypeError after consuming the input).
  values = list(values)
  if not values:
    raise ValueError('calculate average of empty list')
  return float(sum(values)) / len(values)
| 58 | |
| 59 | |
class EntropyStats(object):
  r"""A data structure to maintain the entropy of a set of values.

  This is a math helper for NoisyBinarySearch. Given a set of (possibly
  unnormalized) probabilities, this class incrementally calculates the
  Shannon entropy (natural log) of their normalized distribution.

  Algorithm:
    Given unnormalized p_i, their entropy could be expressed as
      Let S = \sum_i p_i
      Normalized(p): p_i' = p_i/S
      Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                 = \sum_i { -(p_i/S) (log p_i - log S) }
                 = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                 = -1/S * \sum_i { p_i*log p_i } + log S
                 = log S - 1/S * \sum_i { p_i*log p_i }

    So we can maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }| for
    incremental update.

  Attributes:
    sum: S, the sum of all values in the collection.
    sum_log: \sum_i { p_i*log p_i } over all values in the collection.
  """  # pylint: disable=docstring-trailing-quotes

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable of probability values. Could be unnormalized.
    """
    # |p| is traversed twice below. Materialize it so a one-shot iterator
    # (e.g. a generator) does not silently leave sum_log == 0.
    p = list(p)
    self.sum = sum(p)
    self.sum_log = sum(_xlogx(x) for x in p)

  def entropy(self):
    """Returns the entropy of the normalized values.

    Returns 0 when all values sum to zero (nothing to normalize).
    """
    if self.sum == 0:
      return 0
    return log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one random variable in the collection with another.

    Updates the running sums in O(1); the caller is responsible for |old|
    actually being a member of the collection.

    Args:
      old: original value to be replaced.
      new: new value
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Multiplies all values in the collection by |value|.

    The current instance is left unchanged.

    Args:
      value: the scale factor. NOTE(review): presumably non-negative, since
        log() raises ValueError for values at or below -1e-9.

    Returns:
      A new instance of EntropyStats with calculated value.
    """
    # Shallow copy only to obtain an instance without re-running __init__;
    # both attributes are overwritten below.
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    # = \sum_i { (p_i*value) * (log p_i + log value) }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    other.sum *= value
    return other
| 120 | |
| 121 | |
@functools.total_ordering
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    In the formula of noisy binary search algorithm, we need to calculate p**n
    (where n is test count) and normalize probabilities. When n is hundreds or
    more, p**n becomes too small and rounds to zero (underflow), and the
    normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is much less likely to underflow.

  Solution:
    A floating point number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to represent
    `exponent`.

    For mantissa, we can accept calculated probability is not super accurate
    because the worst case is just running test one or two more times. So we
    keep mantissa in python's built-in float.

    For exponent, we represent it using python's int, which has virtually
    unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Attributes:
    mantissa: The mantissa part of floating number. For finite non-zero
      values its range is 0.5 <= abs(mantissa) < 1.0; it is 0.0 for zero.
    exponent: The exponent part of floating number; 0 for zero.
  """

  def __init__(self, value, exp=0):
    """Creates the number value * 2**exp.

    Args:
      value: a real number (anything math.frexp accepts).
      exp: extra power-of-two exponent added to value's own exponent.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Keep zero canonical (exponent 0) regardless of |exp|, so equality by
    # (mantissa, exponent) works.
    self.exponent = self.exponent + exp if value else 0

  @staticmethod
  def _coerce(other):
    """Converts |other| to ExtendedFloat, or returns NotImplemented.

    Used by comparisons so that comparing with a non-numeric object (e.g.
    None or a string) defers to Python's default handling instead of raising
    TypeError from math.frexp.
    """
    if isinstance(other, ExtendedFloat):
      return other
    try:
      return ExtendedFloat(other)
    except TypeError:
      return NotImplemented

  def __float__(self):
    # math.ldexp raises OverflowError if the value exceeds float range and
    # rounds toward 0.0 if it is below it.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Always align the operand with the smaller exponent to the larger one.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Raises this number to the power |p|.

    Exponents of large magnitude are computed by repeated squaring and should
    be ints.
    """
    # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
    # underflow for small p.
    if p < -sys.float_info.min_exp:
      try:
        return ExtendedFloat(self.mantissa**p, self.exponent * p)
      except OverflowError:
        # mantissa**p can only overflow for a large negative p (e.g.
        # 0.5**-2000). Use the reciprocal of self**(-p) instead, whose
        # mantissa stays in range.
        if p >= 0:
          raise
        return ExtendedFloat(1.0).__truediv__(self.__pow__(-p))

    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __truediv__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rtruediv__(self, other):
    return ExtendedFloat(other).__truediv__(self)

  # TODO(kcwu): remove __div__ and __rdiv__ once we migrated to python3
  # Python 2 dispatches `/` to these (without `from __future__ import
  # division`); Python 3 never calls them, so defining them is harmless.
  __div__ = __truediv__
  __rdiv__ = __rtruediv__

  def __repr__(self):
    return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

  def __eq__(self, other):
    other = self._coerce(other)
    if other is NotImplemented:
      return NotImplemented
    # The representation is normalized (see __init__), so equal values have
    # identical (mantissa, exponent) pairs.
    return self.mantissa == other.mantissa and self.exponent == other.exponent

  def __ne__(self, other):
    # Python 2 does not derive __ne__ from __eq__; without this, `!=` would
    # compare object identity there.
    equal = self.__eq__(other)
    if equal is NotImplemented:
      return equal
    return not equal

  def __hash__(self):
    # Must agree with __eq__, including equality with plain floats/ints.
    try:
      return hash(float(self))
    except OverflowError:
      # Too large for a float, so it cannot equal any float; hash the
      # normalized representation instead.
      return hash((self.mantissa, self.exponent))

  def __lt__(self, other):
    other = self._coerce(other)
    if other is NotImplemented:
      return NotImplemented
    # Different signs (or a zero): the mantissa signs decide.
    if self.mantissa * other.mantissa <= 0:
      return self.mantissa < other.mantissa

    if self.exponent == other.exponent:
      return self.mantissa < other.mantissa

    # Same sign: a larger exponent means a larger magnitude.
    if self.mantissa > 0:
      return self.exponent < other.exponent
    return self.exponent > other.exponent

  def __round__(self, ndigits=None):
    # ndigits is optional so that round(x) works like it does for float.
    return round(float(self), ndigits)