blob: d71323efb0fd36c040b3153d07c815fb3ab6b8ac [file] [log] [blame]
Yves Gerey890f62b2019-04-10 17:18:48 +02001/*
2 * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#ifndef RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
12#define RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
13
14#include <algorithm>
15#include <cmath>
16#include <limits>
17
18#include "absl/types/optional.h"
19
20#include "rtc_base/numerics/math_utils.h"
21
22namespace webrtc {
23
24// tl;dr: Robust and efficient online computation of statistics,
25// using Welford's method for variance. [1]
26//
27// This should be your go-to class if you ever need to compute
28// min, max, mean, variance and standard deviation.
29// If you need to get percentiles, please use webrtc::SamplesStatsCounter.
30//
31// The measures return absl::nullopt if no samples were fed (Size() == 0),
32// otherwise the returned optional is guaranteed to contain a value.
33//
34// [1]
35// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
36
37// The type T is a scalar which must be convertible to double.
38// Rationale: we often need greater precision for measures
39// than for the samples themselves.
40template <typename T>
41class RunningStatistics {
42 public:
43 // Update stats ////////////////////////////////////////////
44
45 // Add a value participating in the statistics in O(1) time.
46 void AddSample(T sample) {
47 max_ = std::max(max_, sample);
48 min_ = std::min(min_, sample);
49 ++size_;
50 // Welford's incremental update.
51 const double delta = sample - mean_;
52 mean_ += delta / size_;
53 const double delta2 = sample - mean_;
54 cumul_ += delta * delta2;
55 }
56
57 // Merge other stats, as if samples were added one by one, but in O(1).
58 void MergeStatistics(const RunningStatistics<T>& other) {
59 if (other.size_ == 0) {
60 return;
61 }
62 max_ = std::max(max_, other.max_);
63 min_ = std::min(min_, other.min_);
64 const int64_t new_size = size_ + other.size_;
65 const double new_mean =
66 (mean_ * size_ + other.mean_ * other.size_) / new_size;
67 // Each cumulant must be corrected.
68 // * from: sum((x_i - mean_)²)
69 // * to: sum((x_i - new_mean)²)
70 auto delta = [new_mean](const RunningStatistics<T>& stats) {
71 return stats.size_ * (new_mean * (new_mean - 2 * stats.mean_) +
72 stats.mean_ * stats.mean_);
73 };
74 cumul_ = cumul_ + delta(*this) + other.cumul_ + delta(other);
75 mean_ = new_mean;
76 size_ = new_size;
77 }
78
79 // Get Measures ////////////////////////////////////////////
80
81 // Returns number of samples involved,
82 // that is number of times AddSample() was called.
83 int64_t Size() const { return size_; }
84
85 // Returns min in O(1) time.
86 absl::optional<T> GetMin() const {
87 if (size_ == 0) {
88 return absl::nullopt;
89 }
90 return min_;
91 }
92
93 // Returns max in O(1) time.
94 absl::optional<T> GetMax() const {
95 if (size_ == 0) {
96 return absl::nullopt;
97 }
98 return max_;
99 }
100
101 // Returns mean in O(1) time.
102 absl::optional<double> GetMean() const {
103 if (size_ == 0) {
104 return absl::nullopt;
105 }
106 return mean_;
107 }
108
109 // Returns unbiased sample variance in O(1) time.
110 absl::optional<double> GetVariance() const {
111 if (size_ == 0) {
112 return absl::nullopt;
113 }
114 return cumul_ / size_;
115 }
116
117 // Returns unbiased standard deviation in O(1) time.
118 absl::optional<double> GetStandardDeviation() const {
119 if (size_ == 0) {
120 return absl::nullopt;
121 }
122 return std::sqrt(*GetVariance());
123 }
124
125 private:
126 int64_t size_ = 0; // Samples seen.
127 T min_ = infinity_or_max<T>();
128 T max_ = minus_infinity_or_min<T>();
129 double mean_ = 0;
130 double cumul_ = 0; // Variance * size_, sometimes noted m2.
131};
132
133} // namespace webrtc
134
135#endif // RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_