spirv-diff: Refactor instruction grouping and matching (#4760)
In preparation for supporting OpTypeForwardPointer, which adds more
usages like this, this change refactors the common code used to group
instructions and match the groups.
diff --git a/source/diff/diff.cpp b/source/diff/diff.cpp
index 12172bf..4a15ef7 100644
--- a/source/diff/diff.cpp
+++ b/source/diff/diff.cpp
@@ -40,8 +40,8 @@
// A list of ids with some similar property, for example functions with the same
// name.
using IdGroup = std::vector<uint32_t>;
-// A map of function names to function ids with the same name. This is an
-// ordered map so different implementations produce identical results.
+// A map of names to ids with the same name. This is an ordered map so
+// different implementations produce identical results.
using IdGroupMapByName = std::map<std::string, IdGroup>;
using IdGroupMapByTypeId = std::map<uint32_t, IdGroup>;
@@ -268,7 +268,7 @@
// Helper functions that match ids between src and dst
void PoolPotentialIds(
opt::IteratorRange<opt::Module::const_inst_iterator> section,
- std::vector<uint32_t>& ids,
+ std::vector<uint32_t>& ids, bool is_src,
std::function<bool(const opt::Instruction&)> filter,
std::function<uint32_t(const opt::Instruction&)> get_id);
void MatchIds(
@@ -292,6 +292,42 @@
opt::IteratorRange<opt::Module::const_inst_iterator> src_insts,
opt::IteratorRange<opt::Module::const_inst_iterator> dst_insts);
+ // Get various properties from an id. These helper functions are passed to
+ // `GroupIds` and `GroupIdsAndMatch` below (as the `get_group` argument).
+ uint32_t GroupIdsHelperGetTypeId(const IdInstructions& id_to, uint32_t id);
+
+ // Given a list of ids, groups them based on some value. The `get_group`
+ // function extracts a piece of information corresponding to each id, and the
+ // ids are bucketed based on that (and output in `groups`). This is useful to
+ // attempt to match ids between src and dst only when said property is
+ // identical.
+ template <typename T>
+ void GroupIds(const IdGroup& ids, bool is_src, std::map<T, IdGroup>* groups,
+ T (Differ::*get_group)(const IdInstructions&, uint32_t));
+
+ // Calls GroupIds to bucket ids in src and dst based on a property returned by
+ // `get_group`. This function then calls `match_group` for each bucket (i.e.
+ // "group") with identical values for said property.
+ //
+ // For example, say src and dst ids have the following properties
+ // correspondingly:
+ //
+ // - src ids' properties: {id0: A, id1: A, id2: B, id3: C, id4: B}
+ // - dst ids' properties: {id0': B, id1': C, id2': B, id3': D, id4': B}
+ //
+ // Then `match_group` is called 2 times:
+ //
+ // - Once with: ([id2, id4], [id0', id2', id4']) corresponding to B
+ // - Once with: ([id3], [id1']) corresponding to C
+ //
+ // Ids corresponding to A and D cannot match based on this property.
+ template <typename T>
+ void GroupIdsAndMatch(
+ const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
+ T (Differ::*get_group)(const IdInstructions&, uint32_t),
+ std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+ match_group);
+
// Helper functions that determine if two instructions match
bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
bool DoesOperandMatch(const opt::Operand& src_operand,
@@ -335,14 +371,6 @@
FunctionInstMap* function_insts);
void GetFunctionHeaderInstructions(const opt::Module* module,
FunctionInstMap* function_insts);
- void GroupIdsByName(const IdGroup& functions, bool is_src,
- IdGroupMapByName* groups);
- void GroupIdsByTypeId(const IdGroup& functions, bool is_src,
- IdGroupMapByTypeId* groups);
- template <typename T>
- void GroupIds(const IdGroup& functions, bool is_src,
- std::map<T, IdGroup>* groups,
- std::function<T(const IdInstructions, uint32_t)> get_group);
void BestEffortMatchFunctions(const IdGroup& src_func_ids,
const IdGroup& dst_func_ids,
const FunctionInstMap& src_func_insts,
@@ -374,14 +402,17 @@
uint32_t GetConstantUint(const IdInstructions& id_to, uint32_t constant_id);
SpvExecutionModel GetExecutionModel(const opt::Module* module,
uint32_t entry_point_id);
+ // Get the OpName associated with an id
std::string GetName(const IdInstructions& id_to, uint32_t id, bool* has_name);
- std::string GetFunctionName(const IdInstructions& id_to, uint32_t id);
+ // Get the OpName associated with an id, with argument types stripped for
+ // functions. Some tools don't encode function argument types in the OpName
+ // string, and this improves diff between SPIR-V from those tools and others.
+ std::string GetSanitizedName(const IdInstructions& id_to, uint32_t id);
uint32_t GetVarTypeId(const IdInstructions& id_to, uint32_t var_id,
SpvStorageClass* storage_class);
bool GetDecorationValue(const IdInstructions& id_to, uint32_t id,
SpvDecoration decoration, uint32_t* decoration_value);
bool IsIntType(const IdInstructions& id_to, uint32_t type_id);
- // bool IsUintType(const IdInstructions& id_to, uint32_t type_id);
bool IsFloatType(const IdInstructions& id_to, uint32_t type_id);
bool IsConstantUint(const IdInstructions& id_to, uint32_t id);
bool IsVariable(const IdInstructions& id_to, uint32_t pointer_id);
@@ -548,18 +579,27 @@
void Differ::PoolPotentialIds(
opt::IteratorRange<opt::Module::const_inst_iterator> section,
- std::vector<uint32_t>& ids,
+ std::vector<uint32_t>& ids, bool is_src,
std::function<bool(const opt::Instruction&)> filter,
std::function<uint32_t(const opt::Instruction&)> get_id) {
for (const opt::Instruction& inst : section) {
if (!filter(inst)) {
continue;
}
+
uint32_t result_id = get_id(inst);
assert(result_id != 0);
assert(std::find(ids.begin(), ids.end(), result_id) == ids.end());
+ // Don't include ids that are already matched, for example through
+ // OpTypeForwardPointer.
+ const bool is_matched = is_src ? id_map_.IsSrcMapped(result_id)
+ : id_map_.IsDstMapped(result_id);
+ if (is_matched) {
+ continue;
+ }
+
ids.push_back(result_id);
}
}
@@ -748,6 +788,62 @@
}
}
+uint32_t Differ::GroupIdsHelperGetTypeId(const IdInstructions& id_to,
+ uint32_t id) {
+ return GetInst(id_to, id)->type_id();
+}
+
+template <typename T>
+void Differ::GroupIds(const IdGroup& ids, bool is_src,
+ std::map<T, IdGroup>* groups,
+ T (Differ::*get_group)(const IdInstructions&, uint32_t)) {
+ assert(groups->empty());
+
+ const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
+
+ for (const uint32_t id : ids) {
+ // Don't include ids that are already matched, for example through
+ // OpEntryPoint.
+ const bool is_matched =
+ is_src ? id_map_.IsSrcMapped(id) : id_map_.IsDstMapped(id);
+ if (is_matched) {
+ continue;
+ }
+
+ T group = (this->*get_group)(id_to, id);
+ (*groups)[group].push_back(id);
+ }
+}
+
+template <typename T>
+void Differ::GroupIdsAndMatch(
+ const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
+ T (Differ::*get_group)(const IdInstructions&, uint32_t),
+ std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+ match_group) {
+ // Group the ids based on a key (get_group)
+ std::map<T, IdGroup> src_groups;
+ std::map<T, IdGroup> dst_groups;
+
+ GroupIds<T>(src_ids, true, &src_groups, get_group);
+ GroupIds<T>(dst_ids, false, &dst_groups, get_group);
+
+ // Iterate over the groups, and match those with identical keys
+ for (const auto& iter : src_groups) {
+ const T& key = iter.first;
+ const IdGroup& src_group = iter.second;
+
+ if (key == invalid_group_key) {
+ continue;
+ }
+
+ const IdGroup& dst_group = dst_groups[key];
+
+ // Let the caller match the groups as appropriate.
+ match_group(src_group, dst_group);
+ }
+}
+
bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
assert(dst_id != 0);
return id_map_.MappedDstId(src_id) == dst_id;
@@ -1319,28 +1415,6 @@
}
}
-template <typename T>
-void Differ::GroupIds(
- const IdGroup& functions, bool is_src, std::map<T, IdGroup>* groups,
- std::function<T(const IdInstructions, uint32_t)> get_group) {
- assert(groups->empty());
-
- const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
-
- for (const uint32_t func_id : functions) {
- // Don't include functions that are already matched, for example through
- // OpEntryPoint.
- const bool is_matched =
- is_src ? id_map_.IsSrcMapped(func_id) : id_map_.IsDstMapped(func_id);
- if (is_matched) {
- continue;
- }
-
- T group = get_group(id_to, func_id);
- (*groups)[group].push_back(func_id);
- }
-}
-
void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
const IdGroup& dst_func_ids,
const FunctionInstMap& src_func_insts,
@@ -1361,7 +1435,7 @@
if (id_map_.IsSrcMapped(src_func_id)) {
continue;
}
- const std::string src_name = GetFunctionName(src_id_to_, src_func_id);
+ const std::string src_name = GetSanitizedName(src_id_to_, src_func_id);
for (const uint32_t dst_func_id : dst_func_ids) {
if (id_map_.IsDstMapped(dst_func_id)) {
@@ -1369,7 +1443,7 @@
}
// Don't match functions that are named, but the names are different.
- const std::string dst_name = GetFunctionName(dst_id_to_, dst_func_id);
+ const std::string dst_name = GetSanitizedName(dst_id_to_, dst_func_id);
if (src_name != "" && dst_name != "" && src_name != dst_name) {
continue;
}
@@ -1406,22 +1480,6 @@
}
}
-void Differ::GroupIdsByName(const IdGroup& functions, bool is_src,
- IdGroupMapByName* groups) {
- GroupIds<std::string>(functions, is_src, groups,
- [this](const IdInstructions& id_to, uint32_t func_id) {
- return GetFunctionName(id_to, func_id);
- });
-}
-
-void Differ::GroupIdsByTypeId(const IdGroup& functions, bool is_src,
- IdGroupMapByTypeId* groups) {
- GroupIds<uint32_t>(functions, is_src, groups,
- [this](const IdInstructions& id_to, uint32_t func_id) {
- return GetInst(id_to, func_id)->type_id();
- });
-}
-
void Differ::MatchFunctionParamIds(const opt::Function* src_func,
const opt::Function* dst_func) {
IdGroup src_params;
@@ -1437,52 +1495,33 @@
},
false);
- IdGroupMapByName src_param_groups;
- IdGroupMapByName dst_param_groups;
+ GroupIdsAndMatch<std::string>(
+ src_params, dst_params, "", &Differ::GetSanitizedName,
+ [this](const IdGroup& src_group, const IdGroup& dst_group) {
- GroupIdsByName(src_params, true, &src_param_groups);
- GroupIdsByName(dst_params, false, &dst_param_groups);
-
- // Match parameters with identical names.
- for (const auto& src_param_group : src_param_groups) {
- const std::string& name = src_param_group.first;
- const IdGroup& src_group = src_param_group.second;
-
- if (name == "") {
- continue;
- }
-
- const IdGroup& dst_group = dst_param_groups[name];
-
- // There shouldn't be two parameters with the same name, so the ids should
- // match. There is nothing restricting the SPIR-V however to have two
- // parameters with the same name, so be resilient against that.
- if (src_group.size() == 1 && dst_group.size() == 1) {
- id_map_.MapIds(src_group[0], dst_group[0]);
- }
- }
+ // There shouldn't be two parameters with the same name, so the ids
+ // should match. There is nothing restricting the SPIR-V however to have
+ // two parameters with the same name, so be resilient against that.
+ if (src_group.size() == 1 && dst_group.size() == 1) {
+ id_map_.MapIds(src_group[0], dst_group[0]);
+ }
+ });
// Then match the parameters by their type. If there are multiple of them,
// match them by their order.
- IdGroupMapByTypeId src_param_groups_by_type_id;
- IdGroupMapByTypeId dst_param_groups_by_type_id;
+ GroupIdsAndMatch<uint32_t>(
+ src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
+ [this](const IdGroup& src_group_by_type_id,
+ const IdGroup& dst_group_by_type_id) {
- GroupIdsByTypeId(src_params, true, &src_param_groups_by_type_id);
- GroupIdsByTypeId(dst_params, false, &dst_param_groups_by_type_id);
+ const size_t shared_param_count =
+ std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
- for (const auto& src_param_group_by_type_id : src_param_groups_by_type_id) {
- const uint32_t type_id = src_param_group_by_type_id.first;
- const IdGroup& src_group_by_type_id = src_param_group_by_type_id.second;
- const IdGroup& dst_group_by_type_id = dst_param_groups_by_type_id[type_id];
-
- const size_t shared_param_count =
- std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
-
- for (size_t param_index = 0; param_index < shared_param_count;
- ++param_index) {
- id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
- }
- }
+ for (size_t param_index = 0; param_index < shared_param_count;
+ ++param_index) {
+ id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
+ }
+ });
}
float Differ::MatchFunctionBodies(const InstructionList& src_body,
@@ -1626,7 +1665,7 @@
return "";
}
-std::string Differ::GetFunctionName(const IdInstructions& id_to, uint32_t id) {
+std::string Differ::GetSanitizedName(const IdInstructions& id_to, uint32_t id) {
bool has_name = false;
std::string name = GetName(id_to, id, &has_name);
@@ -1634,7 +1673,7 @@
return "";
}
- // Remove args from the name
+ // Remove args from the name, in case this is a function name
return name.substr(0, name.find('('));
}
@@ -1672,19 +1711,8 @@
bool Differ::IsIntType(const IdInstructions& id_to, uint32_t type_id) {
return IsOp(id_to, type_id, SpvOpTypeInt);
-#if 0
- const opt::Instruction *type_inst = GetInst(id_to, type_id);
- return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] != 0;
-#endif
}
-#if 0
-bool Differ::IsUintType(const IdInstructions& id_to, uint32_t type_id) {
- const opt::Instruction *type_inst = GetInst(id_to, type_id);
- return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] == 0;
-}
-#endif
-
bool Differ::IsFloatType(const IdInstructions& id_to, uint32_t type_id) {
return IsOp(id_to, type_id, SpvOpTypeFloat);
}
@@ -1853,9 +1881,9 @@
};
auto accept_all = [](const opt::Instruction&) { return true; };
- PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids,
+ PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids, true,
accept_all, get_result_id);
- PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids,
+ PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids, false,
accept_all, get_result_id);
// Then match the ids.
@@ -1949,9 +1977,9 @@
return spvOpcodeGeneratesType(inst.opcode());
};
- PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+ PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
- PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+ PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Start with exact matches, then match the leftover with
@@ -2036,9 +2064,9 @@
return spvOpcodeIsConstant(inst.opcode());
};
- PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+ PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
- PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+ PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Constants are matched exactly, except for float types
@@ -2115,9 +2143,9 @@
return inst.opcode() == SpvOpVariable;
};
- PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+ PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
accept_type_ops, get_result_id);
- PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+ PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
accept_type_ops, get_result_id);
// Then match the ids. Start with exact matches, then match the leftover with
@@ -2148,49 +2176,31 @@
}
// Base the matching of functions on debug info when available.
- IdGroupMapByName src_func_groups;
- IdGroupMapByName dst_func_groups;
+ GroupIdsAndMatch<std::string>(
+ src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
+ [this](const IdGroup& src_group, const IdGroup& dst_group) {
- GroupIdsByName(src_func_ids, true, &src_func_groups);
- GroupIdsByName(dst_func_ids, false, &dst_func_groups);
+ // If there is a single function with this name in src and dst, it's a
+ // definite match.
+ if (src_group.size() == 1 && dst_group.size() == 1) {
+ id_map_.MapIds(src_group[0], dst_group[0]);
+ return;
+ }
- // Match functions with identical names.
- for (const auto& src_func_group : src_func_groups) {
- const std::string& name = src_func_group.first;
- const IdGroup& src_group = src_func_group.second;
+ // If there are multiple functions with the same name, group them by
+ // type, and match only if the types match (and are unique).
+ GroupIdsAndMatch<uint32_t>(src_group, dst_group, 0,
+ &Differ::GroupIdsHelperGetTypeId,
+ [this](const IdGroup& src_group_by_type_id,
+ const IdGroup& dst_group_by_type_id) {
- if (name == "") {
- continue;
- }
-
- const IdGroup& dst_group = dst_func_groups[name];
-
- // If there is a single function with this name in src and dst, it's a
- // definite match.
- if (src_group.size() == 1 && dst_group.size() == 1) {
- id_map_.MapIds(src_group[0], dst_group[0]);
- continue;
- }
-
- // If there are multiple functions with the same name, group them by type,
- // and match only if the types match (and are unique).
- IdGroupMapByTypeId src_func_groups_by_type_id;
- IdGroupMapByTypeId dst_func_groups_by_type_id;
-
- GroupIdsByTypeId(src_group, true, &src_func_groups_by_type_id);
- GroupIdsByTypeId(dst_group, false, &dst_func_groups_by_type_id);
-
- for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
- const uint32_t type_id = src_func_group_by_type_id.first;
- const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
- const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
-
- if (src_group_by_type_id.size() == 1 &&
- dst_group_by_type_id.size() == 1) {
- id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
- }
- }
- }
+ if (src_group_by_type_id.size() == 1 &&
+ dst_group_by_type_id.size() == 1) {
+ id_map_.MapIds(src_group_by_type_id[0],
+ dst_group_by_type_id[0]);
+ }
+ });
+ });
// Any functions that are left are pooled together and matched as if unnamed,
// with the only exception that two functions with mismatching names are not
@@ -2224,20 +2234,14 @@
}
// Best effort match functions with matching type.
- IdGroupMapByTypeId src_func_groups_by_type_id;
- IdGroupMapByTypeId dst_func_groups_by_type_id;
+ GroupIdsAndMatch<uint32_t>(
+ src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
+ [this](const IdGroup& src_group_by_type_id,
+ const IdGroup& dst_group_by_type_id) {
- GroupIdsByTypeId(src_func_ids, true, &src_func_groups_by_type_id);
- GroupIdsByTypeId(dst_func_ids, false, &dst_func_groups_by_type_id);
-
- for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
- const uint32_t type_id = src_func_group_by_type_id.first;
- const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
- const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
-
- BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
- src_func_insts_, dst_func_insts_);
- }
+ BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
+ src_func_insts_, dst_func_insts_);
+ });
// Any function that's left, best effort match them.
BestEffortMatchFunctions(src_func_ids, dst_func_ids, src_func_insts_,