spirv-diff: Refactor instruction grouping and matching (#4760)

In preparation for supporting OpTypeForwardPointer, which adds more
usages like this.  This change refactors common code used to group
instructions and match the groups.
diff --git a/source/diff/diff.cpp b/source/diff/diff.cpp
index 12172bf..4a15ef7 100644
--- a/source/diff/diff.cpp
+++ b/source/diff/diff.cpp
@@ -40,8 +40,8 @@
 // A list of ids with some similar property, for example functions with the same
 // name.
 using IdGroup = std::vector<uint32_t>;
-// A map of function names to function ids with the same name.  This is an
-// ordered map so different implementations produce identical results.
+// A map of names to ids with the same name.  This is an ordered map so
+// different implementations produce identical results.
 using IdGroupMapByName = std::map<std::string, IdGroup>;
 using IdGroupMapByTypeId = std::map<uint32_t, IdGroup>;
 
@@ -268,7 +268,7 @@
   // Helper functions that match ids between src and dst
   void PoolPotentialIds(
       opt::IteratorRange<opt::Module::const_inst_iterator> section,
-      std::vector<uint32_t>& ids,
+      std::vector<uint32_t>& ids, bool is_src,
       std::function<bool(const opt::Instruction&)> filter,
       std::function<uint32_t(const opt::Instruction&)> get_id);
   void MatchIds(
@@ -292,6 +292,42 @@
       opt::IteratorRange<opt::Module::const_inst_iterator> src_insts,
       opt::IteratorRange<opt::Module::const_inst_iterator> dst_insts);
 
+  // Get various properties from an id.  These Helper functions are passed to
+  // `GroupIds` and `GroupIdsAndMatch` below (as the `get_group` argument).
+  uint32_t GroupIdsHelperGetTypeId(const IdInstructions& id_to, uint32_t id);
+
+  // Given a list of ids, groups them based on some value.  The `get_group`
+  // function extracts a piece of information corresponding to each id, and the
+  // ids are bucketed based on that (and output in `groups`).  This is useful to
+  // attempt to match ids between src and dst only when said property is
+  // identical.
+  template <typename T>
+  void GroupIds(const IdGroup& ids, bool is_src, std::map<T, IdGroup>* groups,
+                T (Differ::*get_group)(const IdInstructions&, uint32_t));
+
+  // Calls GroupIds to bucket ids in src and dst based on a property returned by
+  // `get_group`.  This function then calls `match_group` for each bucket (i.e.
+  // "group") with identical values for said property.
+  //
+  // For example, say src and dst ids have the following properties
+  // correspondingly:
+  //
+  // - src ids' properties: {id0: A, id1: A, id2: B, id3: C, id4: B}
+  // - dst ids' properties: {id0': B, id1': C, id2': B, id3': D, id4': B}
+  //
+  // Then `match_group` is called 2 times:
+  //
+  // - Once with: ([id2, id4], [id0', id2', id4']) corresponding to B
+  // - Once with: ([id3], [id2']) corresponding to C
+  //
+  // Ids corresponding to A and D cannot match based on this property.
+  template <typename T>
+  void GroupIdsAndMatch(
+      const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
+      T (Differ::*get_group)(const IdInstructions&, uint32_t),
+      std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+          match_group);
+
   // Helper functions that determine if two instructions match
   bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
   bool DoesOperandMatch(const opt::Operand& src_operand,
@@ -335,14 +371,6 @@
                          FunctionInstMap* function_insts);
   void GetFunctionHeaderInstructions(const opt::Module* module,
                                      FunctionInstMap* function_insts);
-  void GroupIdsByName(const IdGroup& functions, bool is_src,
-                      IdGroupMapByName* groups);
-  void GroupIdsByTypeId(const IdGroup& functions, bool is_src,
-                        IdGroupMapByTypeId* groups);
-  template <typename T>
-  void GroupIds(const IdGroup& functions, bool is_src,
-                std::map<T, IdGroup>* groups,
-                std::function<T(const IdInstructions, uint32_t)> get_group);
   void BestEffortMatchFunctions(const IdGroup& src_func_ids,
                                 const IdGroup& dst_func_ids,
                                 const FunctionInstMap& src_func_insts,
@@ -374,14 +402,17 @@
   uint32_t GetConstantUint(const IdInstructions& id_to, uint32_t constant_id);
   SpvExecutionModel GetExecutionModel(const opt::Module* module,
                                       uint32_t entry_point_id);
+  // Get the OpName associated with an id
   std::string GetName(const IdInstructions& id_to, uint32_t id, bool* has_name);
-  std::string GetFunctionName(const IdInstructions& id_to, uint32_t id);
+  // Get the OpName associated with an id, with argument types stripped for
+  // functions.  Some tools don't encode function argument types in the OpName
+  // string, and this improves diff between SPIR-V from those tools and others.
+  std::string GetSanitizedName(const IdInstructions& id_to, uint32_t id);
   uint32_t GetVarTypeId(const IdInstructions& id_to, uint32_t var_id,
                         SpvStorageClass* storage_class);
   bool GetDecorationValue(const IdInstructions& id_to, uint32_t id,
                           SpvDecoration decoration, uint32_t* decoration_value);
   bool IsIntType(const IdInstructions& id_to, uint32_t type_id);
-  // bool IsUintType(const IdInstructions& id_to, uint32_t type_id);
   bool IsFloatType(const IdInstructions& id_to, uint32_t type_id);
   bool IsConstantUint(const IdInstructions& id_to, uint32_t id);
   bool IsVariable(const IdInstructions& id_to, uint32_t pointer_id);
@@ -548,18 +579,27 @@
 
 void Differ::PoolPotentialIds(
     opt::IteratorRange<opt::Module::const_inst_iterator> section,
-    std::vector<uint32_t>& ids,
+    std::vector<uint32_t>& ids, bool is_src,
     std::function<bool(const opt::Instruction&)> filter,
     std::function<uint32_t(const opt::Instruction&)> get_id) {
   for (const opt::Instruction& inst : section) {
     if (!filter(inst)) {
       continue;
     }
+
     uint32_t result_id = get_id(inst);
     assert(result_id != 0);
 
     assert(std::find(ids.begin(), ids.end(), result_id) == ids.end());
 
+    // Don't include ids that are already matched, for example through
+    // OpTypeForwardPointer.
+    const bool is_matched = is_src ? id_map_.IsSrcMapped(result_id)
+                                   : id_map_.IsDstMapped(result_id);
+    if (is_matched) {
+      continue;
+    }
+
     ids.push_back(result_id);
   }
 }
@@ -748,6 +788,62 @@
   }
 }
 
+uint32_t Differ::GroupIdsHelperGetTypeId(const IdInstructions& id_to,
+                                         uint32_t id) {
+  return GetInst(id_to, id)->type_id();
+}
+
+template <typename T>
+void Differ::GroupIds(const IdGroup& ids, bool is_src,
+                      std::map<T, IdGroup>* groups,
+                      T (Differ::*get_group)(const IdInstructions&, uint32_t)) {
+  assert(groups->empty());
+
+  const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
+
+  for (const uint32_t id : ids) {
+    // Don't include ids that are already matched, for example through
+    // OpEntryPoint.
+    const bool is_matched =
+        is_src ? id_map_.IsSrcMapped(id) : id_map_.IsDstMapped(id);
+    if (is_matched) {
+      continue;
+    }
+
+    T group = (this->*get_group)(id_to, id);
+    (*groups)[group].push_back(id);
+  }
+}
+
+template <typename T>
+void Differ::GroupIdsAndMatch(
+    const IdGroup& src_ids, const IdGroup& dst_ids, T invalid_group_key,
+    T (Differ::*get_group)(const IdInstructions&, uint32_t),
+    std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
+        match_group) {
+  // Group the ids based on a key (get_group)
+  std::map<T, IdGroup> src_groups;
+  std::map<T, IdGroup> dst_groups;
+
+  GroupIds<T>(src_ids, true, &src_groups, get_group);
+  GroupIds<T>(dst_ids, false, &dst_groups, get_group);
+
+  // Iterate over the groups, and match those with identical keys
+  for (const auto& iter : src_groups) {
+    const T& key = iter.first;
+    const IdGroup& src_group = iter.second;
+
+    if (key == invalid_group_key) {
+      continue;
+    }
+
+    const IdGroup& dst_group = dst_groups[key];
+
+    // Let the caller match the groups as appropriate.
+    match_group(src_group, dst_group);
+  }
+}
+
 bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
   assert(dst_id != 0);
   return id_map_.MappedDstId(src_id) == dst_id;
@@ -1319,28 +1415,6 @@
   }
 }
 
-template <typename T>
-void Differ::GroupIds(
-    const IdGroup& functions, bool is_src, std::map<T, IdGroup>* groups,
-    std::function<T(const IdInstructions, uint32_t)> get_group) {
-  assert(groups->empty());
-
-  const IdInstructions& id_to = is_src ? src_id_to_ : dst_id_to_;
-
-  for (const uint32_t func_id : functions) {
-    // Don't include functions that are already matched, for example through
-    // OpEntryPoint.
-    const bool is_matched =
-        is_src ? id_map_.IsSrcMapped(func_id) : id_map_.IsDstMapped(func_id);
-    if (is_matched) {
-      continue;
-    }
-
-    T group = get_group(id_to, func_id);
-    (*groups)[group].push_back(func_id);
-  }
-}
-
 void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,
                                       const IdGroup& dst_func_ids,
                                       const FunctionInstMap& src_func_insts,
@@ -1361,7 +1435,7 @@
     if (id_map_.IsSrcMapped(src_func_id)) {
       continue;
     }
-    const std::string src_name = GetFunctionName(src_id_to_, src_func_id);
+    const std::string src_name = GetSanitizedName(src_id_to_, src_func_id);
 
     for (const uint32_t dst_func_id : dst_func_ids) {
       if (id_map_.IsDstMapped(dst_func_id)) {
@@ -1369,7 +1443,7 @@
       }
 
       // Don't match functions that are named, but the names are different.
-      const std::string dst_name = GetFunctionName(dst_id_to_, dst_func_id);
+      const std::string dst_name = GetSanitizedName(dst_id_to_, dst_func_id);
       if (src_name != "" && dst_name != "" && src_name != dst_name) {
         continue;
       }
@@ -1406,22 +1480,6 @@
   }
 }
 
-void Differ::GroupIdsByName(const IdGroup& functions, bool is_src,
-                            IdGroupMapByName* groups) {
-  GroupIds<std::string>(functions, is_src, groups,
-                        [this](const IdInstructions& id_to, uint32_t func_id) {
-                          return GetFunctionName(id_to, func_id);
-                        });
-}
-
-void Differ::GroupIdsByTypeId(const IdGroup& functions, bool is_src,
-                              IdGroupMapByTypeId* groups) {
-  GroupIds<uint32_t>(functions, is_src, groups,
-                     [this](const IdInstructions& id_to, uint32_t func_id) {
-                       return GetInst(id_to, func_id)->type_id();
-                     });
-}
-
 void Differ::MatchFunctionParamIds(const opt::Function* src_func,
                                    const opt::Function* dst_func) {
   IdGroup src_params;
@@ -1437,52 +1495,33 @@
       },
       false);
 
-  IdGroupMapByName src_param_groups;
-  IdGroupMapByName dst_param_groups;
+  GroupIdsAndMatch<std::string>(
+      src_params, dst_params, "", &Differ::GetSanitizedName,
+      [this](const IdGroup& src_group, const IdGroup& dst_group) {
 
-  GroupIdsByName(src_params, true, &src_param_groups);
-  GroupIdsByName(dst_params, false, &dst_param_groups);
-
-  // Match parameters with identical names.
-  for (const auto& src_param_group : src_param_groups) {
-    const std::string& name = src_param_group.first;
-    const IdGroup& src_group = src_param_group.second;
-
-    if (name == "") {
-      continue;
-    }
-
-    const IdGroup& dst_group = dst_param_groups[name];
-
-    // There shouldn't be two parameters with the same name, so the ids should
-    // match. There is nothing restricting the SPIR-V however to have two
-    // parameters with the same name, so be resilient against that.
-    if (src_group.size() == 1 && dst_group.size() == 1) {
-      id_map_.MapIds(src_group[0], dst_group[0]);
-    }
-  }
+        // There shouldn't be two parameters with the same name, so the ids
+        // should match. There is nothing restricting the SPIR-V however to have
+        // two parameters with the same name, so be resilient against that.
+        if (src_group.size() == 1 && dst_group.size() == 1) {
+          id_map_.MapIds(src_group[0], dst_group[0]);
+        }
+      });
 
   // Then match the parameters by their type.  If there are multiple of them,
   // match them by their order.
-  IdGroupMapByTypeId src_param_groups_by_type_id;
-  IdGroupMapByTypeId dst_param_groups_by_type_id;
+  GroupIdsAndMatch<uint32_t>(
+      src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
+      [this](const IdGroup& src_group_by_type_id,
+             const IdGroup& dst_group_by_type_id) {
 
-  GroupIdsByTypeId(src_params, true, &src_param_groups_by_type_id);
-  GroupIdsByTypeId(dst_params, false, &dst_param_groups_by_type_id);
+        const size_t shared_param_count =
+            std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
 
-  for (const auto& src_param_group_by_type_id : src_param_groups_by_type_id) {
-    const uint32_t type_id = src_param_group_by_type_id.first;
-    const IdGroup& src_group_by_type_id = src_param_group_by_type_id.second;
-    const IdGroup& dst_group_by_type_id = dst_param_groups_by_type_id[type_id];
-
-    const size_t shared_param_count =
-        std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());
-
-    for (size_t param_index = 0; param_index < shared_param_count;
-         ++param_index) {
-      id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
-    }
-  }
+        for (size_t param_index = 0; param_index < shared_param_count;
+             ++param_index) {
+          id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
+        }
+      });
 }
 
 float Differ::MatchFunctionBodies(const InstructionList& src_body,
@@ -1626,7 +1665,7 @@
   return "";
 }
 
-std::string Differ::GetFunctionName(const IdInstructions& id_to, uint32_t id) {
+std::string Differ::GetSanitizedName(const IdInstructions& id_to, uint32_t id) {
   bool has_name = false;
   std::string name = GetName(id_to, id, &has_name);
 
@@ -1634,7 +1673,7 @@
     return "";
   }
 
-  // Remove args from the name
+  // Remove args from the name, in case this is a function name
   return name.substr(0, name.find('('));
 }
 
@@ -1672,19 +1711,8 @@
 
 bool Differ::IsIntType(const IdInstructions& id_to, uint32_t type_id) {
   return IsOp(id_to, type_id, SpvOpTypeInt);
-#if 0
-  const opt::Instruction *type_inst = GetInst(id_to, type_id);
-  return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] != 0;
-#endif
 }
 
-#if 0
-bool Differ::IsUintType(const IdInstructions& id_to, uint32_t type_id) {
-  const opt::Instruction *type_inst = GetInst(id_to, type_id);
-  return type_inst->opcode() == SpvOpTypeInt && type_inst->GetInOperand(1).words[0] == 0;
-}
-#endif
-
 bool Differ::IsFloatType(const IdInstructions& id_to, uint32_t type_id) {
   return IsOp(id_to, type_id, SpvOpTypeFloat);
 }
@@ -1853,9 +1881,9 @@
   };
   auto accept_all = [](const opt::Instruction&) { return true; };
 
-  PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids,
+  PoolPotentialIds(src_->ext_inst_imports(), potential_id_map.src_ids, true,
                    accept_all, get_result_id);
-  PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids,
+  PoolPotentialIds(dst_->ext_inst_imports(), potential_id_map.dst_ids, false,
                    accept_all, get_result_id);
 
   // Then match the ids.
@@ -1949,9 +1977,9 @@
     return spvOpcodeGeneratesType(inst.opcode());
   };
 
-  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
                    accept_type_ops, get_result_id);
-  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
                    accept_type_ops, get_result_id);
 
   // Then match the ids.  Start with exact matches, then match the leftover with
@@ -2036,9 +2064,9 @@
     return spvOpcodeIsConstant(inst.opcode());
   };
 
-  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
                    accept_type_ops, get_result_id);
-  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
                    accept_type_ops, get_result_id);
 
   // Then match the ids.  Constants are matched exactly, except for float types
@@ -2115,9 +2143,9 @@
     return inst.opcode() == SpvOpVariable;
   };
 
-  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids,
+  PoolPotentialIds(src_->types_values(), potential_id_map.src_ids, true,
                    accept_type_ops, get_result_id);
-  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids,
+  PoolPotentialIds(dst_->types_values(), potential_id_map.dst_ids, false,
                    accept_type_ops, get_result_id);
 
   // Then match the ids.  Start with exact matches, then match the leftover with
@@ -2148,49 +2176,31 @@
   }
 
   // Base the matching of functions on debug info when available.
-  IdGroupMapByName src_func_groups;
-  IdGroupMapByName dst_func_groups;
+  GroupIdsAndMatch<std::string>(
+      src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
+      [this](const IdGroup& src_group, const IdGroup& dst_group) {
 
-  GroupIdsByName(src_func_ids, true, &src_func_groups);
-  GroupIdsByName(dst_func_ids, false, &dst_func_groups);
+        // If there is a single function with this name in src and dst, it's a
+        // definite match.
+        if (src_group.size() == 1 && dst_group.size() == 1) {
+          id_map_.MapIds(src_group[0], dst_group[0]);
+          return;
+        }
 
-  // Match functions with identical names.
-  for (const auto& src_func_group : src_func_groups) {
-    const std::string& name = src_func_group.first;
-    const IdGroup& src_group = src_func_group.second;
+        // If there are multiple functions with the same name, group them by
+        // type, and match only if the types match (and are unique).
+        GroupIdsAndMatch<uint32_t>(src_group, dst_group, 0,
+                                   &Differ::GroupIdsHelperGetTypeId,
+                                   [this](const IdGroup& src_group_by_type_id,
+                                          const IdGroup& dst_group_by_type_id) {
 
-    if (name == "") {
-      continue;
-    }
-
-    const IdGroup& dst_group = dst_func_groups[name];
-
-    // If there is a single function with this name in src and dst, it's a
-    // definite match.
-    if (src_group.size() == 1 && dst_group.size() == 1) {
-      id_map_.MapIds(src_group[0], dst_group[0]);
-      continue;
-    }
-
-    // If there are multiple functions with the same name, group them by type,
-    // and match only if the types match (and are unique).
-    IdGroupMapByTypeId src_func_groups_by_type_id;
-    IdGroupMapByTypeId dst_func_groups_by_type_id;
-
-    GroupIdsByTypeId(src_group, true, &src_func_groups_by_type_id);
-    GroupIdsByTypeId(dst_group, false, &dst_func_groups_by_type_id);
-
-    for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
-      const uint32_t type_id = src_func_group_by_type_id.first;
-      const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
-      const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
-
-      if (src_group_by_type_id.size() == 1 &&
-          dst_group_by_type_id.size() == 1) {
-        id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
-      }
-    }
-  }
+                                     if (src_group_by_type_id.size() == 1 &&
+                                         dst_group_by_type_id.size() == 1) {
+                                       id_map_.MapIds(src_group_by_type_id[0],
+                                                      dst_group_by_type_id[0]);
+                                     }
+                                   });
+      });
 
   // Any functions that are left are pooled together and matched as if unnamed,
   // with the only exception that two functions with mismatching names are not
@@ -2224,20 +2234,14 @@
   }
 
   // Best effort match functions with matching type.
-  IdGroupMapByTypeId src_func_groups_by_type_id;
-  IdGroupMapByTypeId dst_func_groups_by_type_id;
+  GroupIdsAndMatch<uint32_t>(
+      src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
+      [this](const IdGroup& src_group_by_type_id,
+             const IdGroup& dst_group_by_type_id) {
 
-  GroupIdsByTypeId(src_func_ids, true, &src_func_groups_by_type_id);
-  GroupIdsByTypeId(dst_func_ids, false, &dst_func_groups_by_type_id);
-
-  for (const auto& src_func_group_by_type_id : src_func_groups_by_type_id) {
-    const uint32_t type_id = src_func_group_by_type_id.first;
-    const IdGroup& src_group_by_type_id = src_func_group_by_type_id.second;
-    const IdGroup& dst_group_by_type_id = dst_func_groups_by_type_id[type_id];
-
-    BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
-                             src_func_insts_, dst_func_insts_);
-  }
+        BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
+                                 src_func_insts_, dst_func_insts_);
+      });
 
   // Any function that's left, best effort match them.
   BestEffortMatchFunctions(src_func_ids, dst_func_ids, src_func_insts_,