Shahbaz Youssefi | 7fa9e74 | 2022-02-02 10:33:18 -0500 | [diff] [blame] | 1 | // Copyright (c) 2022 Google LLC. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #ifndef SOURCE_DIFF_LCS_H_ |
| 16 | #define SOURCE_DIFF_LCS_H_ |
| 17 | |
| 18 | #include <algorithm> |
| 19 | #include <cassert> |
| 20 | #include <cstddef> |
| 21 | #include <functional> |
| 22 | #include <vector> |
| 23 | |
| 24 | namespace spvtools { |
| 25 | namespace diff { |
| 26 | |
| 27 | // The result of a diff. |
| 28 | using DiffMatch = std::vector<bool>; |
| 29 | |
| 30 | // Helper class to find the longest common subsequence between two function |
| 31 | // bodies. |
| 32 | template <typename Sequence> |
| 33 | class LongestCommonSubsequence { |
| 34 | public: |
| 35 | LongestCommonSubsequence(const Sequence& src, const Sequence& dst) |
| 36 | : src_(src), |
| 37 | dst_(dst), |
| 38 | table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {} |
| 39 | |
| 40 | // Given two sequences, it creates a matching between them. The elements are |
| 41 | // simply marked as matched in src and dst, with any unmatched element in src |
| 42 | // implying a removal and any unmatched element in dst implying an addition. |
| 43 | // |
| 44 | // Returns the length of the longest common subsequence. |
| 45 | template <typename T> |
| 46 | size_t Get(std::function<bool(T src_elem, T dst_elem)> match, |
| 47 | DiffMatch* src_match_result, DiffMatch* dst_match_result); |
| 48 | |
| 49 | private: |
| 50 | template <typename T> |
| 51 | size_t CalculateLCS(size_t src_start, size_t dst_start, |
| 52 | std::function<bool(T src_elem, T dst_elem)> match); |
| 53 | void RetrieveMatch(DiffMatch* src_match_result, DiffMatch* dst_match_result); |
| 54 | bool IsInBound(size_t src_index, size_t dst_index) { |
| 55 | return src_index < src_.size() && dst_index < dst_.size(); |
| 56 | } |
| 57 | bool IsCalculated(size_t src_index, size_t dst_index) { |
| 58 | assert(IsInBound(src_index, dst_index)); |
| 59 | return table_[src_index][dst_index].valid; |
| 60 | } |
| 61 | size_t GetMemoizedLength(size_t src_index, size_t dst_index) { |
| 62 | if (!IsInBound(src_index, dst_index)) { |
| 63 | return 0; |
| 64 | } |
| 65 | assert(IsCalculated(src_index, dst_index)); |
| 66 | return table_[src_index][dst_index].best_match_length; |
| 67 | } |
| 68 | bool IsMatched(size_t src_index, size_t dst_index) { |
| 69 | assert(IsCalculated(src_index, dst_index)); |
| 70 | return table_[src_index][dst_index].matched; |
| 71 | } |
| 72 | |
| 73 | const Sequence& src_; |
| 74 | const Sequence& dst_; |
| 75 | |
| 76 | struct DiffMatchEntry { |
| 77 | size_t best_match_length = 0; |
| 78 | // Whether src[i] and dst[j] matched. This is an optimization to avoid |
| 79 | // calling the `match` function again when walking the LCS table. |
| 80 | bool matched = false; |
| 81 | // Use for the recursive algorithm to know if the contents of this entry are |
| 82 | // valid. |
| 83 | bool valid = false; |
| 84 | }; |
| 85 | |
| 86 | std::vector<std::vector<DiffMatchEntry>> table_; |
| 87 | }; |
| 88 | |
| 89 | template <typename Sequence> |
| 90 | template <typename T> |
| 91 | size_t LongestCommonSubsequence<Sequence>::Get( |
| 92 | std::function<bool(T src_elem, T dst_elem)> match, |
| 93 | DiffMatch* src_match_result, DiffMatch* dst_match_result) { |
| 94 | size_t best_match_length = CalculateLCS(0, 0, match); |
| 95 | RetrieveMatch(src_match_result, dst_match_result); |
| 96 | return best_match_length; |
| 97 | } |
| 98 | |
| 99 | template <typename Sequence> |
| 100 | template <typename T> |
| 101 | size_t LongestCommonSubsequence<Sequence>::CalculateLCS( |
| 102 | size_t src_start, size_t dst_start, |
| 103 | std::function<bool(T src_elem, T dst_elem)> match) { |
| 104 | // The LCS algorithm is simple. Given sequences s and d, with a:b depicting a |
| 105 | // range in python syntax: |
| 106 | // |
| 107 | // lcs(s[i:], d[j:]) = |
| 108 | // lcs(s[i+1:], d[j+1:]) + 1 if s[i] == d[j] |
| 109 | // max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:])) o.w. |
| 110 | // |
| 111 | // Once the LCS table is filled according to the above, it can be walked and |
| 112 | // the best match retrieved. |
| 113 | // |
| 114 | // This is a recursive function with memoization, which avoids filling table |
| 115 | // entries where unnecessary. This makes the best case O(N) instead of |
| 116 | // O(N^2). |
| 117 | |
| 118 | // To avoid unnecessary recursion on long sequences, process a whole strip of |
| 119 | // matching elements in one go. |
| 120 | size_t src_cur = src_start; |
| 121 | size_t dst_cur = dst_start; |
| 122 | while (IsInBound(src_cur, dst_cur) && !IsCalculated(src_cur, dst_cur) && |
| 123 | match(src_[src_cur], dst_[dst_cur])) { |
| 124 | ++src_cur; |
| 125 | ++dst_cur; |
| 126 | } |
| 127 | |
| 128 | // We've reached a pair of elements that don't match. Recursively determine |
| 129 | // which one should be left unmatched. |
| 130 | size_t best_match_length = 0; |
| 131 | if (IsInBound(src_cur, dst_cur)) { |
| 132 | if (IsCalculated(src_cur, dst_cur)) { |
| 133 | best_match_length = GetMemoizedLength(src_cur, dst_cur); |
| 134 | } else { |
| 135 | best_match_length = std::max(CalculateLCS(src_cur + 1, dst_cur, match), |
| 136 | CalculateLCS(src_cur, dst_cur + 1, match)); |
| 137 | |
| 138 | // Fill the table with this information |
| 139 | DiffMatchEntry& entry = table_[src_cur][dst_cur]; |
| 140 | assert(!entry.valid); |
| 141 | entry.best_match_length = best_match_length; |
| 142 | entry.valid = true; |
| 143 | } |
| 144 | } |
| 145 | |
| 146 | // Go over the matched strip and update the table as well. |
| 147 | assert(src_cur - src_start == dst_cur - dst_start); |
| 148 | size_t contiguous_match_len = src_cur - src_start; |
| 149 | |
| 150 | for (size_t i = 0; i < contiguous_match_len; ++i) { |
| 151 | --src_cur; |
| 152 | --dst_cur; |
| 153 | assert(IsInBound(src_cur, dst_cur)); |
| 154 | |
| 155 | DiffMatchEntry& entry = table_[src_cur][dst_cur]; |
| 156 | assert(!entry.valid); |
| 157 | entry.best_match_length = ++best_match_length; |
| 158 | entry.matched = true; |
| 159 | entry.valid = true; |
| 160 | } |
| 161 | |
| 162 | return best_match_length; |
| 163 | } |
| 164 | |
| 165 | template <typename Sequence> |
| 166 | void LongestCommonSubsequence<Sequence>::RetrieveMatch( |
| 167 | DiffMatch* src_match_result, DiffMatch* dst_match_result) { |
| 168 | src_match_result->clear(); |
| 169 | dst_match_result->clear(); |
| 170 | |
| 171 | src_match_result->resize(src_.size(), false); |
| 172 | dst_match_result->resize(dst_.size(), false); |
| 173 | |
| 174 | size_t src_cur = 0; |
| 175 | size_t dst_cur = 0; |
| 176 | while (IsInBound(src_cur, dst_cur)) { |
| 177 | if (IsMatched(src_cur, dst_cur)) { |
| 178 | (*src_match_result)[src_cur++] = true; |
| 179 | (*dst_match_result)[dst_cur++] = true; |
| 180 | continue; |
| 181 | } |
| 182 | |
| 183 | if (GetMemoizedLength(src_cur + 1, dst_cur) >= |
| 184 | GetMemoizedLength(src_cur, dst_cur + 1)) { |
| 185 | ++src_cur; |
| 186 | } else { |
| 187 | ++dst_cur; |
| 188 | } |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | } // namespace diff |
| 193 | } // namespace spvtools |
| 194 | |
| 195 | #endif // SOURCE_DIFF_LCS_H_ |