blob: 6c00e86499ee6fa75170ffacfcbf9a50cb33ac5e [file] [log] [blame]
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -05001// Copyright (c) 2022 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef SOURCE_DIFF_LCS_H_
16#define SOURCE_DIFF_LCS_H_
17
18#include <algorithm>
19#include <cassert>
20#include <cstddef>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050021#include <cstdint>
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050022#include <functional>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050023#include <stack>
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050024#include <vector>
25
26namespace spvtools {
27namespace diff {
28
29// The result of a diff.
30using DiffMatch = std::vector<bool>;
31
32// Helper class to find the longest common subsequence between two function
33// bodies.
34template <typename Sequence>
35class LongestCommonSubsequence {
36 public:
37 LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
38 : src_(src),
39 dst_(dst),
40 table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}
41
42 // Given two sequences, it creates a matching between them. The elements are
43 // simply marked as matched in src and dst, with any unmatched element in src
44 // implying a removal and any unmatched element in dst implying an addition.
45 //
46 // Returns the length of the longest common subsequence.
47 template <typename T>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050048 uint32_t Get(std::function<bool(T src_elem, T dst_elem)> match,
49 DiffMatch* src_match_result, DiffMatch* dst_match_result);
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050050
51 private:
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050052 struct DiffMatchIndex {
53 uint32_t src_offset;
54 uint32_t dst_offset;
55 };
56
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050057 template <typename T>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050058 void CalculateLCS(std::function<bool(T src_elem, T dst_elem)> match);
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050059 void RetrieveMatch(DiffMatch* src_match_result, DiffMatch* dst_match_result);
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050060 bool IsInBound(DiffMatchIndex index) {
61 return index.src_offset < src_.size() && index.dst_offset < dst_.size();
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050062 }
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050063 bool IsCalculated(DiffMatchIndex index) {
64 assert(IsInBound(index));
65 return table_[index.src_offset][index.dst_offset].valid;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050066 }
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050067 bool IsCalculatedOrOutOfBound(DiffMatchIndex index) {
68 return !IsInBound(index) || IsCalculated(index);
69 }
70 uint32_t GetMemoizedLength(DiffMatchIndex index) {
71 if (!IsInBound(index)) {
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050072 return 0;
73 }
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050074 assert(IsCalculated(index));
75 return table_[index.src_offset][index.dst_offset].best_match_length;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050076 }
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050077 bool IsMatched(DiffMatchIndex index) {
78 assert(IsCalculated(index));
79 return table_[index.src_offset][index.dst_offset].matched;
80 }
81 void MarkMatched(DiffMatchIndex index, uint32_t best_match_length,
82 bool matched) {
83 assert(IsInBound(index));
84 DiffMatchEntry& entry = table_[index.src_offset][index.dst_offset];
85 assert(!entry.valid);
86
87 entry.best_match_length = best_match_length & 0x3FFFFFFF;
88 assert(entry.best_match_length == best_match_length);
89 entry.matched = matched;
90 entry.valid = true;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -050091 }
92
93 const Sequence& src_;
94 const Sequence& dst_;
95
96 struct DiffMatchEntry {
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -050097 DiffMatchEntry() : best_match_length(0), matched(false), valid(false) {}
98
99 uint32_t best_match_length : 30;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500100 // Whether src[i] and dst[j] matched. This is an optimization to avoid
101 // calling the `match` function again when walking the LCS table.
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500102 uint32_t matched : 1;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500103 // Use for the recursive algorithm to know if the contents of this entry are
104 // valid.
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500105 uint32_t valid : 1;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500106 };
107
108 std::vector<std::vector<DiffMatchEntry>> table_;
109};
110
111template <typename Sequence>
112template <typename T>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500113uint32_t LongestCommonSubsequence<Sequence>::Get(
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500114 std::function<bool(T src_elem, T dst_elem)> match,
115 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500116 CalculateLCS(match);
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500117 RetrieveMatch(src_match_result, dst_match_result);
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500118 return GetMemoizedLength({0, 0});
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500119}
120
121template <typename Sequence>
122template <typename T>
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500123void LongestCommonSubsequence<Sequence>::CalculateLCS(
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500124 std::function<bool(T src_elem, T dst_elem)> match) {
125 // The LCS algorithm is simple. Given sequences s and d, with a:b depicting a
126 // range in python syntax:
127 //
128 // lcs(s[i:], d[j:]) =
129 // lcs(s[i+1:], d[j+1:]) + 1 if s[i] == d[j]
130 // max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:])) o.w.
131 //
132 // Once the LCS table is filled according to the above, it can be walked and
133 // the best match retrieved.
134 //
135 // This is a recursive function with memoization, which avoids filling table
136 // entries where unnecessary. This makes the best case O(N) instead of
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500137 // O(N^2). The implemention uses a std::stack to avoid stack overflow on long
138 // sequences.
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500139
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500140 if (src_.empty() || dst_.empty()) {
141 return;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500142 }
143
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500144 std::stack<DiffMatchIndex> to_calculate;
145 to_calculate.push({0, 0});
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500146
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500147 while (!to_calculate.empty()) {
148 DiffMatchIndex current = to_calculate.top();
149 to_calculate.pop();
150 assert(IsInBound(current));
151
152 // If already calculated through another path, ignore it.
153 if (IsCalculated(current)) {
154 continue;
155 }
156
157 if (match(src_[current.src_offset], dst_[current.dst_offset])) {
158 // If the current elements match, advance both indices and calculate the
159 // LCS if not already. Visit `current` again afterwards, so its
160 // corresponding entry will be updated.
161 DiffMatchIndex next = {current.src_offset + 1, current.dst_offset + 1};
162 if (IsCalculatedOrOutOfBound(next)) {
163 MarkMatched(current, GetMemoizedLength(next) + 1, true);
164 } else {
165 to_calculate.push(current);
166 to_calculate.push(next);
167 }
168 continue;
169 }
170
171 // We've reached a pair of elements that don't match. Calculate the LCS for
172 // both cases of either being left unmatched and take the max. Visit
173 // `current` again afterwards, so its corresponding entry will be updated.
174 DiffMatchIndex next_src = {current.src_offset + 1, current.dst_offset};
175 DiffMatchIndex next_dst = {current.src_offset, current.dst_offset + 1};
176
177 if (IsCalculatedOrOutOfBound(next_src) &&
178 IsCalculatedOrOutOfBound(next_dst)) {
179 uint32_t best_match_length =
180 std::max(GetMemoizedLength(next_src), GetMemoizedLength(next_dst));
181 MarkMatched(current, best_match_length, false);
182 continue;
183 }
184
185 to_calculate.push(current);
186 if (!IsCalculatedOrOutOfBound(next_src)) {
187 to_calculate.push(next_src);
188 }
189 if (!IsCalculatedOrOutOfBound(next_dst)) {
190 to_calculate.push(next_dst);
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500191 }
192 }
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500193}
194
195template <typename Sequence>
196void LongestCommonSubsequence<Sequence>::RetrieveMatch(
197 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
198 src_match_result->clear();
199 dst_match_result->clear();
200
201 src_match_result->resize(src_.size(), false);
202 dst_match_result->resize(dst_.size(), false);
203
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500204 DiffMatchIndex current = {0, 0};
205 while (IsInBound(current)) {
206 if (IsMatched(current)) {
207 (*src_match_result)[current.src_offset++] = true;
208 (*dst_match_result)[current.dst_offset++] = true;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500209 continue;
210 }
211
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500212 if (GetMemoizedLength({current.src_offset + 1, current.dst_offset}) >=
213 GetMemoizedLength({current.src_offset, current.dst_offset + 1})) {
214 ++current.src_offset;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500215 } else {
Shahbaz Youssefi9beb5452022-02-07 09:37:04 -0500216 ++current.dst_offset;
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -0500217 }
218 }
219}
220
221} // namespace diff
222} // namespace spvtools
223
224#endif // SOURCE_DIFF_LCS_H_