blob: 486f43dc4d3f441fb95f0514c5078558e47396d8 [file] [log] [blame]
Shahbaz Youssefi7fa9e742022-02-02 10:33:18 -05001// Copyright (c) 2022 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef SOURCE_DIFF_LCS_H_
16#define SOURCE_DIFF_LCS_H_
17
18#include <algorithm>
19#include <cassert>
20#include <cstddef>
21#include <functional>
22#include <vector>
23
24namespace spvtools {
25namespace diff {
26
27// The result of a diff.
28using DiffMatch = std::vector<bool>;
29
30// Helper class to find the longest common subsequence between two function
31// bodies.
32template <typename Sequence>
33class LongestCommonSubsequence {
34 public:
35 LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
36 : src_(src),
37 dst_(dst),
38 table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}
39
40 // Given two sequences, it creates a matching between them. The elements are
41 // simply marked as matched in src and dst, with any unmatched element in src
42 // implying a removal and any unmatched element in dst implying an addition.
43 //
44 // Returns the length of the longest common subsequence.
45 template <typename T>
46 size_t Get(std::function<bool(T src_elem, T dst_elem)> match,
47 DiffMatch* src_match_result, DiffMatch* dst_match_result);
48
49 private:
50 template <typename T>
51 size_t CalculateLCS(size_t src_start, size_t dst_start,
52 std::function<bool(T src_elem, T dst_elem)> match);
53 void RetrieveMatch(DiffMatch* src_match_result, DiffMatch* dst_match_result);
54 bool IsInBound(size_t src_index, size_t dst_index) {
55 return src_index < src_.size() && dst_index < dst_.size();
56 }
57 bool IsCalculated(size_t src_index, size_t dst_index) {
58 assert(IsInBound(src_index, dst_index));
59 return table_[src_index][dst_index].valid;
60 }
61 size_t GetMemoizedLength(size_t src_index, size_t dst_index) {
62 if (!IsInBound(src_index, dst_index)) {
63 return 0;
64 }
65 assert(IsCalculated(src_index, dst_index));
66 return table_[src_index][dst_index].best_match_length;
67 }
68 bool IsMatched(size_t src_index, size_t dst_index) {
69 assert(IsCalculated(src_index, dst_index));
70 return table_[src_index][dst_index].matched;
71 }
72
73 const Sequence& src_;
74 const Sequence& dst_;
75
76 struct DiffMatchEntry {
77 size_t best_match_length = 0;
78 // Whether src[i] and dst[j] matched. This is an optimization to avoid
79 // calling the `match` function again when walking the LCS table.
80 bool matched = false;
81 // Use for the recursive algorithm to know if the contents of this entry are
82 // valid.
83 bool valid = false;
84 };
85
86 std::vector<std::vector<DiffMatchEntry>> table_;
87};
88
89template <typename Sequence>
90template <typename T>
91size_t LongestCommonSubsequence<Sequence>::Get(
92 std::function<bool(T src_elem, T dst_elem)> match,
93 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
94 size_t best_match_length = CalculateLCS(0, 0, match);
95 RetrieveMatch(src_match_result, dst_match_result);
96 return best_match_length;
97}
98
99template <typename Sequence>
100template <typename T>
101size_t LongestCommonSubsequence<Sequence>::CalculateLCS(
102 size_t src_start, size_t dst_start,
103 std::function<bool(T src_elem, T dst_elem)> match) {
104 // The LCS algorithm is simple. Given sequences s and d, with a:b depicting a
105 // range in python syntax:
106 //
107 // lcs(s[i:], d[j:]) =
108 // lcs(s[i+1:], d[j+1:]) + 1 if s[i] == d[j]
109 // max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:])) o.w.
110 //
111 // Once the LCS table is filled according to the above, it can be walked and
112 // the best match retrieved.
113 //
114 // This is a recursive function with memoization, which avoids filling table
115 // entries where unnecessary. This makes the best case O(N) instead of
116 // O(N^2).
117
118 // To avoid unnecessary recursion on long sequences, process a whole strip of
119 // matching elements in one go.
120 size_t src_cur = src_start;
121 size_t dst_cur = dst_start;
122 while (IsInBound(src_cur, dst_cur) && !IsCalculated(src_cur, dst_cur) &&
123 match(src_[src_cur], dst_[dst_cur])) {
124 ++src_cur;
125 ++dst_cur;
126 }
127
128 // We've reached a pair of elements that don't match. Recursively determine
129 // which one should be left unmatched.
130 size_t best_match_length = 0;
131 if (IsInBound(src_cur, dst_cur)) {
132 if (IsCalculated(src_cur, dst_cur)) {
133 best_match_length = GetMemoizedLength(src_cur, dst_cur);
134 } else {
135 best_match_length = std::max(CalculateLCS(src_cur + 1, dst_cur, match),
136 CalculateLCS(src_cur, dst_cur + 1, match));
137
138 // Fill the table with this information
139 DiffMatchEntry& entry = table_[src_cur][dst_cur];
140 assert(!entry.valid);
141 entry.best_match_length = best_match_length;
142 entry.valid = true;
143 }
144 }
145
146 // Go over the matched strip and update the table as well.
147 assert(src_cur - src_start == dst_cur - dst_start);
148 size_t contiguous_match_len = src_cur - src_start;
149
150 for (size_t i = 0; i < contiguous_match_len; ++i) {
151 --src_cur;
152 --dst_cur;
153 assert(IsInBound(src_cur, dst_cur));
154
155 DiffMatchEntry& entry = table_[src_cur][dst_cur];
156 assert(!entry.valid);
157 entry.best_match_length = ++best_match_length;
158 entry.matched = true;
159 entry.valid = true;
160 }
161
162 return best_match_length;
163}
164
165template <typename Sequence>
166void LongestCommonSubsequence<Sequence>::RetrieveMatch(
167 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
168 src_match_result->clear();
169 dst_match_result->clear();
170
171 src_match_result->resize(src_.size(), false);
172 dst_match_result->resize(dst_.size(), false);
173
174 size_t src_cur = 0;
175 size_t dst_cur = 0;
176 while (IsInBound(src_cur, dst_cur)) {
177 if (IsMatched(src_cur, dst_cur)) {
178 (*src_match_result)[src_cur++] = true;
179 (*dst_match_result)[dst_cur++] = true;
180 continue;
181 }
182
183 if (GetMemoizedLength(src_cur + 1, dst_cur) >=
184 GetMemoizedLength(src_cur, dst_cur + 1)) {
185 ++src_cur;
186 } else {
187 ++dst_cur;
188 }
189 }
190}
191
192} // namespace diff
193} // namespace spvtools
194
195#endif // SOURCE_DIFF_LCS_H_