Kuang-che Wu | 6e4beca | 2018-06-27 17:45:02 +0800 | [diff] [blame] | 1 | # -*- coding: utf-8 -*- |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 2 | # Copyright 2017 The Chromium OS Authors. All rights reserved. |
| 3 | # Use of this source code is governed by a BSD-style license that can be |
| 4 | # found in the LICENSE file. |
| 5 | """Math utilities.""" |
| 6 | |
| 7 | from __future__ import print_function |
| 8 | |
| 9 | import copy |
Kuang-che Wu | d3ff586 | 2019-11-26 11:41:05 +0800 | [diff] [blame] | 10 | import functools |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 11 | import math |
| 12 | import sys |
| 13 | |
| 14 | |
def log(x):
  """Returns the natural logarithm of x, tolerating zero and tiny negatives.

  Floating point round-off can push a quantity that should be exactly zero
  slightly below zero. Values in the interval (-1e-9, 0] are therefore mapped
  to 0.0 instead of raising. A tolerance of -1e-10 was not wide enough; see
  strategy_test.test_next_idx_arithmetic_error for details.

  Args:
    x: the number to take the logarithm of.

  Returns:
    math.log(x), or 0.0 when -1e-9 < x <= 0.

  Raises:
    ValueError: if x <= -1e-9 (raised by math.log).
  """
  near_zero = -1e-9 < x <= 0
  if not near_zero:
    return math.log(x)
  return 0.0
| 23 | |
| 24 | |
def _xlogx(x):
  """Returns x * log(x), using the zero-tolerant log() of this module.

  Because log() maps 0 (and tiny negatives) to 0.0, the product is zero
  there, which matches the limit of x*log(x) as x approaches 0 from above.
  """
  logarithm = log(x)
  return x * logarithm
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 27 | |
| 28 | |
def least(values):
  """Returns the smallest strictly positive element of |values|.

  Args:
    values: an iterable of numbers, expected to be non-negative.

  Returns:
    The minimum among the elements that are greater than zero.

  Raises:
    ValueError: if no element is greater than zero (including when |values|
      is empty).
  """
  positives = [v for v in values if v > 0]
  return min(positives)
| 42 | |
| 43 | |
def average(values):
  """Returns the arithmetic mean of |values| as a float.

  Args:
    values: a non-empty sequence of numbers. It is tested for emptiness,
      summed and passed to len(), so it must support all three.

  Returns:
    sum(values) / len(values), computed in floating point.

  Raises:
    ValueError: if |values| is empty.
  """
  if not values:
    raise ValueError('calculate average of empty list')
  numerator = float(sum(values))
  denominator = len(values)
  return numerator / denominator
| 56 | |
| 57 | |
class Averager:
  """Accumulates numbers one at a time and reports their mean.

  Attributes:
    count: number of values added so far.
    total: sum of the values added so far.
  """

  def __init__(self):
    self.count, self.total = 0, 0

  def add(self, value):
    """Adds |value| to the running statistics."""
    self.count += 1
    self.total += value

  def average(self):
    """Returns total / count (true division).

    Raises:
      ValueError: if no value has been added yet.
    """
    if not self.count:
      raise ValueError('calculate average of empty list')
    return self.total / self.count
| 73 | |
| 74 | |
class EntropyStats:
  r"""Incrementally maintains the entropy of a set of probability values.

  This is a math helper for NoisyBinarySearch. Given a set of (possibly
  unnormalized) probabilities, this class can incrementally calculate the
  entropy of their normalized distribution. The quantity was historically
  called "cross entropy" here; it is the Shannon entropy
  -\sum_i p_i' log p_i' of the normalized values p_i'.

  Algorithm:
    Given unnormalized p_i, their entropy can be expressed as
    Let S = \sum_i p_i
    Normalized(p): p_i' = p_i/S
    Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
               = \sum_i { -(p_i/S) (log p_i - log S) }
               = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
               = -1/S * \sum_i { p_i*log p_i } + log S
               = log S - 1/S * \sum_i { p_i*log p_i }

    So we maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }| for
    incremental updates.

  Attributes:
    sum: S, the sum of all values.
    sum_log: \sum_i { p_i*log p_i } over all values.
  """  # pylint: disable=docstring-trailing-quotes

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable of probability values. Could be unnormalized. It is
        consumed exactly once, so generators are accepted.
    """
    # Materialize first: |p| is needed twice below, and a one-shot iterator
    # would otherwise be exhausted by the first sum(), leaving sum_log == 0.
    values = list(p)
    self.sum = sum(values)
    self.sum_log = sum(_xlogx(x) for x in values)

  def entropy(self):
    """Returns the entropy of the normalized values.

    Returns 0 when the values sum to zero, since there is nothing to
    normalize.
    """
    if self.sum == 0:
      return 0
    return log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one value in the collection with another, in place.

    Whether |old| is actually one of the values is not checked; the caller
    is responsible for that.

    Args:
      old: original value to be replaced.
      new: new value.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Multiplies all values in the collection by |value|.

    |self| is left unchanged. Scaling every value by the same positive factor
    leaves the normalized distribution, and hence entropy(), unchanged (up to
    rounding).

    Args:
      value: the scaling factor.

    Returns:
      A new instance of EntropyStats with calculated value.

    Raises:
      ValueError: if |value| <= -1e-9 (raised by log()).
    """
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    # = \sum_i { (p_i*value) * (log p_i + log value) }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    # = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    other.sum *= value
    return other
| 135 | |
| 136 | |
@functools.total_ordering
class ExtendedFloat:
  """Custom floating point number with a larger exponent range.

  Problem:
    In the formula of the noisy binary search algorithm, we need to calculate
    p**n (where n is the test count) and normalize probabilities. When n is
    in the hundreds or more, p**n becomes too small and rounds to zero
    (underflow), and the normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is much less likely to underflow.

  Solution:
    A number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to
    represent `exponent`.

    For the mantissa, we can accept that calculated probabilities are not
    super accurate, because the worst case is just running a test one or two
    more times. So we keep the mantissa in python's built-in float.

    For the exponent, we use python's int, which has virtually unlimited
    digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Attributes:
    mantissa: The mantissa part of floating number. For finite non-zero
      values its range is 0.5 <= abs(mantissa) < 1.0; it is 0.0 for zero.
    exponent: The exponent part of floating number (0 for zero).
  """

  def __init__(self, value, exp=0):
    """Creates the number value * 2.**exp.

    Args:
      value: a real number (anything math.frexp accepts).
      exp: extra power-of-two exponent to apply.
    """
    # math.frexp splits value into mantissa * 2**e with 0.5 <= |mantissa| < 1
    # (or mantissa == 0.0 for zero).
    self.mantissa, self.exponent = math.frexp(value)
    # Zero keeps exponent 0 so that all zeros compare equal in __eq__, which
    # compares the two fields directly.
    self.exponent = self.exponent + exp if value else 0

  @staticmethod
  def _coerce(value):
    """Returns |value| as an ExtendedFloat, converting plain numbers.

    Raises:
      TypeError: if |value| is not a real number.
    """
    if isinstance(value, ExtendedFloat):
      return value
    return ExtendedFloat(value)

  def __float__(self):
    # May underflow to 0.0, or raise OverflowError, when the exponent is
    # outside the range of a python float.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    other = self._coerce(other)
    if self.mantissa == 0:
      return other
    # Always scale the operand with the smaller exponent down to the larger
    # one; the scaled mantissa may harmlessly underflow to 0.0.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Returns self**p.

    Args:
      p: the power. Values below -sys.float_info.min_exp (1021 for IEEE
        doubles) are applied to the mantissa directly. Larger values use
        recursive squaring, which requires an int (it uses >> and &).
    """
    # Because 0.5 <= abs(mantissa) < 1.0, using float is enough without
    # underflow for small p.
    # NOTE(review): a large negative p also takes this branch, and
    # mantissa**p may raise OverflowError there.
    if p < -sys.float_info.min_exp:
      return ExtendedFloat(self.mantissa**p, self.exponent * p)

    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    other = self._coerce(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __truediv__(self, other):
    other = self._coerce(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rtruediv__(self, other):
    return ExtendedFloat(other).__truediv__(self)

  def __repr__(self):
    # The mantissa is printed with '%f', i.e. only 6 decimal places.
    return 'ExtendedFloat(%f, %d)' % (self.mantissa, self.exponent)

  def __eq__(self, other):
    try:
      other = self._coerce(other)
    except TypeError:
      # Not a number: let Python fall back, so == yields False and != True
      # instead of raising.
      return NotImplemented
    # Both operands are normalized by frexp, so field equality is value
    # equality.
    return self.mantissa == other.mantissa and self.exponent == other.exponent

  def __lt__(self, other):
    try:
      other = self._coerce(other)
    except TypeError:
      # Python then raises the standard "'<' not supported" TypeError.
      return NotImplemented
    # Opposite signs, or at least one zero: the mantissa signs decide.
    if self.mantissa * other.mantissa <= 0:
      return self.mantissa < other.mantissa

    if self.exponent == other.exponent:
      return self.mantissa < other.mantissa

    # Same sign, different magnitude: a larger exponent means a larger
    # magnitude, which is larger for positives and smaller for negatives.
    if self.mantissa > 0:
      return self.exponent < other.exponent
    return self.exponent > other.exponent

  def __round__(self, digits=None):
    # digits defaults to None so that the builtin round(x) works and, like
    # round(float), returns an int.
    return round(float(self), digits)