Kuang-che Wu | 6e4beca | 2018-06-27 17:45:02 +0800 | [diff] [blame] | 1 | # -*- coding: utf-8 -*- |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 2 | # Copyright 2017 The Chromium OS Authors. All rights reserved. |
| 3 | # Use of this source code is governed by a BSD-style license that can be |
| 4 | # found in the LICENSE file. |
| 5 | """Math utilities.""" |
| 6 | |
| 7 | from __future__ import print_function |
| 8 | |
| 9 | import copy |
| 10 | import math |
| 11 | import sys |
| 12 | |
Kuang-che Wu | a7ddf9b | 2019-11-25 18:59:57 +0800 | [diff] [blame^] | 13 | import six |
| 14 | |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 15 | |
def log(x):
  """Natural logarithm that tolerates zero and tiny negative inputs.

  Accumulated floating point error can turn a value that should be zero into
  a very small negative number. Such values (and zero itself) yield 0.0
  instead of raising.

  Args:
    x: number to take the logarithm of.

  Returns:
    0.0 if -1e-9 < x <= 0, otherwise math.log(x).

  Raises:
    ValueError: from math.log if x <= -1e-9.
  """
  # A threshold of -1e-10 is not wide enough, see
  # strategy_test.test_next_idx_arithmetic_error for detail.
  is_effectively_zero = x <= 0 and x > -1e-9
  if is_effectively_zero:
    return 0.0
  return math.log(x)
| 24 | |
| 25 | |
def _xlogx(x):
  """Returns x * log(x), using this module's zero-tolerant log()."""
  log_of_x = log(x)
  return x * log_of_x
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 28 | |
| 29 | |
def least(values):
  """Returns the smallest strictly positive element.

  Args:
    values: iterable of numbers; elements that are zero or negative are
        ignored.

  Returns:
    The minimum among the elements greater than zero.

  Raises:
    ValueError: if no element is greater than zero.
  """
  positives = [v for v in values if v > 0]
  return min(positives)
| 43 | |
| 44 | |
def average(values):
  """Computes the arithmetic mean.

  Args:
    values: sized collection of numbers with at least one element.

  Returns:
    The mean as a float.

  Raises:
    ValueError: if |values| is empty.
  """
  if values:
    # Convert to float first so integer inputs never use integer division.
    total = float(sum(values))
    return total / len(values)
  raise ValueError('calculate average of empty list')
| 57 | |
| 58 | |
class EntropyStats(object):
  r"""Incrementally maintained cross entropy of a collection of weights.

  Math helper for NoisyBinarySearch. The weights p_i do not have to be
  normalized. With S = \sum_i p_i, the cross entropy of the normalized
  values p_i/S is

    \sum_i { -(p_i/S) * log(p_i/S) }
      = \sum_i { -(p_i/S) * (log p_i - log S) }
      = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
      = log S - 1/S * \sum_i { p_i * log p_i }

  so keeping two running totals is enough for incremental updates.

  Attributes:
    sum: S, the sum of all weights.
    sum_log: \sum_i { p_i * log p_i }.
  """  # pylint: disable=docstring-trailing-quotes

  def __init__(self, p):
    """Builds the running totals from the initial weights.

    Args:
      p: A list of probability values. Could be unnormalized.
    """
    self.sum = sum(p)
    self.sum_log = sum(_xlogx(weight) for weight in p)

  def entropy(self):
    """Returns the cross entropy, or 0 when the weights sum to zero."""
    total = self.sum
    if total == 0:
      return 0
    return log(total) - self.sum_log / total

  def replace(self, old, new):
    """Swaps one weight of the collection for another.

    Args:
      old: weight currently in the collection that is being removed.
      new: weight taking its place.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Scales every weight in the collection by |value|.

    This instance is left untouched.

    Args:
      value: the factor applied to each weight.

    Returns:
      A new instance of EntropyStats describing the scaled weights.
    """
    scaled = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    #   = \sum_i { (p_i*value) * (log p_i + log value) }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    scaled.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    scaled.sum *= value
    return scaled
| 119 | |
| 120 | |
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    In the formula of noisy binary search algorithm, we need to calculate p**n
    (where n is test count) and normalize probabilities. When n is hundreds or
    more, p**n goes too small and becomes zero (underflow), the normalization
    step will fail.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is less likely to (or never does) underflow.

  Solution:
    Suppose a floating number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to represent
    `exponent`.

    For mantissa, we can accept calculated probability is not super accurate
    because the worst case is just running test one or two more times. So we
    keep mantissa in python's built-in float.

    For exponent, we represent it using python's int, which is virtually
    unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Attributes:
    mantissa: The mantissa part of floating number. Its range is 0.5 <=
        abs(mantissa) < 1.0 (or exactly 0 for the value zero).
    exponent: The exponent part of floating number. Zero is always stored
        with exponent 0.
  """

  def __init__(self, value, exp=0):
    """Initializes with the value |value| * 2.**|exp|.

    Args:
      value: a number accepted by math.frexp.
      exp: extra binary exponent (int) applied on top of |value|.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Test the normalized mantissa instead of |value| so that zero gets
    # exponent 0 even if |value| is an object that is always truthy.
    if self.mantissa:
      self.exponent += exp
    else:
      self.exponent = 0

  def __float__(self):
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Always align the operand with the smaller exponent to the larger one;
    # the shifted mantissa may underflow to zero, which is the right answer.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Raises to the power |p|.

    Args:
      p: the power. The implementation shifts and bit-tests |p| and multiplies
          the integer exponent by it, so |p| has to be an integer.
    """
    # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
    # underflow for small p.
    if p < -sys.float_info.min_exp:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    # Exponentiation by squaring for large p.
    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __truediv__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rtruediv__(self, other):
    # Must call __truediv__: __div__ is only defined under python2, so
    # calling it here would raise AttributeError under python3.
    return ExtendedFloat(other).__truediv__(self)

  # TODO(kcwu): remove __div__ and __rdiv__ once we migrated to python3
  if sys.version_info[0] == 2:
    __div__ = __truediv__
    __rdiv__ = __rtruediv__

  def __repr__(self):
    return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

  def _compare(self, other):
    """Three-way comparison.

    Args:
      other: ExtendedFloat or a number convertible to one.

    Returns:
      -1, 0 or 1 if self is less than, equal to or greater than |other|.
    """
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    diff = self - other
    # Equivalent of python2's cmp(diff.mantissa, 0), which python3 lacks.
    return (diff.mantissa > 0) - (diff.mantissa < 0)

  # python2 only; python3 ignores __cmp__ and uses the rich comparisons below.
  def __cmp__(self, other):
    return self._compare(other)

  def __eq__(self, other):
    try:
      return self._compare(other) == 0
    except TypeError:
      # |other| is not a number; let python fall back to its default.
      return NotImplemented

  def __ne__(self, other):
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __lt__(self, other):
    return self._compare(other) < 0

  def __le__(self, other):
    return self._compare(other) <= 0

  def __gt__(self, other):
    return self._compare(other) > 0

  def __ge__(self, other):
    return self._compare(other) >= 0

  def __hash__(self):
    # Defining __eq__ removes the inherited __hash__ under python3, so
    # provide one. Use the float's hash when the value fits into a float so
    # that an ExtendedFloat equal to a float hashes the same as that float.
    try:
      return hash(float(self))
    except OverflowError:
      return hash((self.mantissa, self.exponent))