Kuang-che Wu | 6e4beca | 2018-06-27 17:45:02 +0800 | [diff] [blame] | 1 | # -*- coding: utf-8 -*- |
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 2 | # Copyright 2017 The Chromium OS Authors. All rights reserved. |
| 3 | # Use of this source code is governed by a BSD-style license that can be |
| 4 | # found in the LICENSE file. |
| 5 | """Math utilities.""" |
| 6 | |
| 7 | from __future__ import print_function |
| 8 | |
| 9 | import copy |
| 10 | import math |
| 11 | import sys |
| 12 | |
| 13 | |
def log(x):
  """Natural logarithm that tolerates zero and tiny negative inputs.

  Accumulated arithmetic error can push a value that should be exactly zero
  slightly below it. Any input in the window (-1e-9, 0] is therefore mapped
  to 0.0 instead of raising. A window of 1e-10 proved too narrow; see
  strategy_test.test_next_idx_arithmetic_error.

  Args:
    x: The value to take the logarithm of.

  Returns:
    0.0 if -1e-9 < x <= 0, otherwise math.log(x).

  Raises:
    ValueError: if x <= -1e-9 (propagated from math.log).
  """
  tolerance = 1e-9
  if -tolerance < x <= 0:
    return 0.0
  return math.log(x)
| 22 | |
| 23 | |
def _xlogx(x):
  """Returns x * log(x), the per-term contribution to an entropy sum.

  Because log() maps 0 (and tiny negatives) to 0.0, the term for a zero
  value is 0 rather than an error.
  """
  logarithm = log(x)
  return x * logarithm
Kuang-che Wu | 88875db | 2017-07-20 10:47:53 +0800 | [diff] [blame] | 26 | |
| 27 | |
def least(values):
  """Returns the smallest strictly positive element of |values|.

  Zero (and negative) entries are ignored.

  Args:
    values: Non-negative values.

  Returns:
    The minimum value that is greater than zero.

  Raises:
    ValueError: if no value is positive (including when |values| is empty).
  """
  positives = filter(lambda v: v > 0, values)
  return min(positives)
| 41 | |
| 42 | |
class EntropyStats(object):
  r"""A data structure to maintain the entropy of a set of values.

  This is a math helper for NoisyBinarySearch. Given a set of probabilities,
  possibly unnormalized, this class incrementally calculates the (Shannon)
  entropy of their normalized distribution.

  Algorithm:
    Given unnormalized p_i, their entropy could be expressed as
      Let S = \sum_i p_i
      Normalized(p): p_i' = p_i/S
      Entropy(p) = \sum_i { -(p_i/S) * log (p_i/S) }
                 = \sum_i { -(p_i/S) (log p_i - log S) }
                 = \sum_i { -p_i/S * log p_i } + \sum_i { p_i/S * log S }
                 = -1/S * \sum_i { p_i*log p_i } + log S
                 = log S - 1/S * \sum_i { p_i*log p_i }

    So we can maintain sum=|S| and sum_log=|\sum_i { p_i*log p_i }| for
    incremental update.

  Attributes:
    sum: S, the sum of all values in the collection.
    sum_log: \sum_i { p_i*log p_i } over all values in the collection.
  """

  def __init__(self, p):
    """Initializes EntropyStats.

    Args:
      p: An iterable (e.g. a list) of probability values. Could be
        unnormalized.
    """
    # |p| is traversed twice below. Materialize it first so that a one-shot
    # iterator or generator is not exhausted by the first sum(), which would
    # silently leave sum_log at 0.
    values = list(p)
    self.sum = sum(values)
    self.sum_log = sum(map(_xlogx, values))

  def entropy(self):
    """Returns the entropy of the normalized distribution.

    Returns 0 when the total weight is zero, since there is nothing to
    normalize.
    """
    if self.sum == 0:
      return 0
    return log(self.sum) - self.sum_log / self.sum

  def replace(self, old, new):
    """Replaces one random variable in the collection with another.

    Membership of |old| in the collection is not checked; the caller must
    pass a value that was previously added.

    Args:
      old: original value to be replaced.
      new: new value.
    """
    self.sum += new - old
    self.sum_log += _xlogx(new) - _xlogx(old)

  def multiply(self, value):
    """Multiplies all values in the collection by |value|.

    Args:
      value: The factor every value in the collection is scaled by.

    Returns:
      A new instance of EntropyStats with calculated value. self is left
      unchanged.
    """
    other = copy.copy(self)
    # \sum_i { (p_i*value) * log (p_i*value) }
    #   = \sum_i { (p_i*value) * (log p_i + log value) }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i * value * log value }
    #   = value * \sum_i { p_i log p_i } + \sum_i { p_i } * value * log value
    other.sum_log = value * self.sum_log + self.sum * _xlogx(value)

    other.sum *= value
    return other
| 103 | |
| 104 | |
class ExtendedFloat(object):
  """Custom floating point number with larger range of exponent.

  Problem:
    In the formula of noisy binary search algorithm, we need to calculate p**n
    (where n is test count) and normalize probabilities. When n is hundreds or
    more, p**n becomes too small and rounds to zero (underflow), and the
    normalization step fails.

    Since n > 100 is not unreasonably large for noisy bisection, we need a
    numeric type that is much less likely (or never) to underflow.

  Solution:
    A floating point number is represented as
      x = mantissa * 2.**exponent
    We use python's float to represent `mantissa` and python's int to represent
    `exponent`.

    For mantissa, we can accept that the calculated probability is not super
    accurate, because the worst case is just running the test one or two more
    times. So we keep mantissa in python's built-in float.

    For exponent, we represent it using python's int, which has virtually
    unlimited digits.

  Alternative solutions:
    1. Maybe we can rewrite the formula to avoid underflow. But I want to keep
       the arithmetic code simple.

    2. Use python stdlib's decimal or fractions. But they are very slow --
       decimal is 6-50x slower, fractions is 9-760x slower while ExtendedFloat
       is about 1.2-3.3x slower [1].

    [1] test setup: easy N=1000, oracle=(0.01, 0.2)
                    hard N=1000, oracle=(0.45, 0.55)

  Attributes:
    mantissa: The mantissa part of floating number. For finite non-zero values
      its range is 0.5 <= abs(mantissa) < 1.0; it is 0 for zero.
    exponent: The exponent part of floating number; it is 0 for zero.
  """

  def __init__(self, value, exp=0):
    """Initializes ExtendedFloat to value * 2.**exp.

    Args:
      value: A number to be normalized into mantissa/exponent form.
      exp: Additional power-of-two exponent applied on top of |value|.
    """
    self.mantissa, self.exponent = math.frexp(value)
    # Zero keeps a canonical exponent of 0 regardless of |exp|.
    self.exponent = self.exponent + exp if value else 0

  def __float__(self):
    # May raise OverflowError when the value exceeds the float range.
    return math.ldexp(self.mantissa, self.exponent)

  def __neg__(self):
    return ExtendedFloat(-self.mantissa, self.exponent)

  def __abs__(self):
    return ExtendedFloat(abs(self.mantissa), self.exponent)

  def __add__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    if self.mantissa == 0:
      return other
    # Always align the operand with the smaller exponent onto the larger one.
    if self.exponent < other.exponent:
      return other.__add__(self)
    value = math.ldexp(other.mantissa, other.exponent - self.exponent)
    return ExtendedFloat(value + self.mantissa, self.exponent)

  __radd__ = __add__

  def __pow__(self, p):
    """Raises self to the integer power |p|.

    Powers too large for a plain float power are computed by repeated
    squaring, so no intermediate result leaves the float range.
    """
    limit = -sys.float_info.min_exp
    if -limit < p < limit:
      # Because 0.5 <= abs(mantissa) < 1.0, a non-zero mantissa**p stays
      # within [2.**-limit, 2.**limit] for such p, so a plain float power
      # neither underflows nor overflows.
      return ExtendedFloat(self.mantissa**p, self.exponent * p)
    if p < 0:
      # mantissa**p can overflow for large negative p (e.g. 0.5**-2000);
      # take the reciprocal of the positive power instead.
      return ExtendedFloat(1.0).__div__(self.__pow__(-p))

    half_pow = self.__pow__(p >> 1)
    result = half_pow.__mul__(half_pow)
    if p & 1:
      result = result.__mul__(self)
    return result

  def __sub__(self, other):
    return self.__add__(-other)

  def __rsub__(self, other):
    return ExtendedFloat(other).__add__(-self)

  def __mul__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa * other.mantissa,
                         self.exponent + other.exponent)

  __rmul__ = __mul__

  def __div__(self, other):
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    return ExtendedFloat(self.mantissa / other.mantissa,
                         self.exponent - other.exponent)

  def __rdiv__(self, other):
    return ExtendedFloat(other).__div__(self)

  # `/` dispatches to these under `from __future__ import division` and in
  # Python 3, where __div__/__rdiv__ are never consulted.
  __truediv__ = __div__
  __rtruediv__ = __rdiv__

  def __repr__(self):
    return "ExtendedFloat(%f, %d)" % (self.mantissa, self.exponent)

  def _compare(self, other):
    """Three-way compares self with |other|.

    Returns:
      -1, 0 or 1 if self is less than, equal to, or greater than |other|.
    """
    if not isinstance(other, ExtendedFloat):
      other = ExtendedFloat(other)
    sign = (self - other).mantissa
    # Same result as Python 2's cmp(sign, 0), without the Py2-only builtin.
    return (sign > 0) - (sign < 0)

  def __cmp__(self, other):
    # Only consulted by Python 2.
    return self._compare(other)

  def __eq__(self, other):
    try:
      return self._compare(other) == 0
    except TypeError:
      # Not a number; let Python fall back to its default handling.
      return NotImplemented

  def __ne__(self, other):
    try:
      return self._compare(other) != 0
    except TypeError:
      return NotImplemented

  def __lt__(self, other):
    return self._compare(other) < 0

  def __le__(self, other):
    return self._compare(other) <= 0

  def __gt__(self, other):
    return self._compare(other) > 0

  def __ge__(self, other):
    return self._compare(other) >= 0

  # Defining __eq__ sets __hash__ to None under Python 3. Keep the inherited
  # identity-based hash the class had before rich comparisons were added.
  __hash__ = object.__hash__