// Copyright 2017 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Cluster POD kernel arguments.
//
// Collect plain-old-data kernel arguments and place them into a single
// struct argument, at the end. Other arguments are pointers, and retain
// their relative order.
//
// We will create a kernel function as the new entry point, and change
// the original kernel function into a regular SPIR function. Key
// kernel metadata is moved from the old function to the wrapper.
// We also attach a "kernel_arg_map" metadata node to the function to
// encode the mapping from old kernel argument to new kernel argument.
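//
// For example (schematic, with illustrative names; the trailing podargs
// struct is only added when POD args are not passed via global push
// constants):
//
//   kernel void foo(global int *out, int n, float scale)
//
// becomes a wrapper entry point
//
//   foo(global int *out, { int n; float scale; } podargs)
//
// that extracts n and scale from podargs and calls the original function,
// now renamed foo.inner and demoted to an ordinary SPIR function.
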
#include <algorithm>
#include <cassert>
#include <cstring>

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include "clspv/AddressSpace.h"
#include "clspv/Option.h"

#include "ArgKind.h"
#include "Constants.h"
#include "Passes.h"
#include "PushConstant.h"

using namespace llvm;

#define DEBUG_TYPE "clusterpodkernelargs"

namespace {
const uint64_t kIntBytes = 4;

struct ClusterPodKernelArgumentsPass : public ModulePass {
  static char ID;
  ClusterPodKernelArgumentsPass() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

private:
  // Returns the type-mangled struct for global pod args. Only generates
  // unpacked structs currently. The type conversion code does not handle
  // packed structs properly. AutoPodArgsPass would also need updates to
  // support packed structs.
  StructType *GetTypeMangledPodArgsStruct(Module &M);

  // (Re-)Declares the global push constant variable with |mangled_struct_ty|
  // as the last member.
  void RedeclareGlobalPushConstants(Module &M, StructType *mangled_struct_ty);

  // Converts the corresponding elements of the global push constants for pod
  // args in member |index| of |pod_struct|.
  Value *ConvertToType(Module &M, StructType *pod_struct, unsigned index,
                       IRBuilder<> &builder);

  // Builds |dst_type| from |elements|, where |elements| is a vector of i32
  // loads.
  Value *BuildFromElements(Module &M, IRBuilder<> &builder, Type *dst_type,
                           uint64_t base_offset, uint64_t base_index,
                           const std::vector<Value *> &elements);
};

} // namespace

char ClusterPodKernelArgumentsPass::ID = 0;
INITIALIZE_PASS(ClusterPodKernelArgumentsPass, "ClusterPodKernelArgumentsPass",
                "Cluster POD Kernel Arguments Pass", false, false)

namespace clspv {
llvm::ModulePass *createClusterPodKernelArgumentsPass() {
  return new ClusterPodKernelArgumentsPass();
}
} // namespace clspv

bool ClusterPodKernelArgumentsPass::runOnModule(Module &M) {
  bool Changed = false;
  LLVMContext &Context = M.getContext();

  SmallVector<Function *, 8> WorkList;

  for (Function &F : M) {
    if (F.isDeclaration() || F.getCallingConv() != CallingConv::SPIR_KERNEL) {
      continue;
    }
    for (Argument &Arg : F.args()) {
      if (!isa<PointerType>(Arg.getType())) {
        WorkList.push_back(&F);
        break;
      }
    }
  }

  SmallVector<CallInst *, 8> CallList;

  // If any of the kernels call for type-mangled push constants, we need to
  // know the right type and base offset.
  const uint64_t global_push_constant_size = clspv::GlobalPushConstantsSize(M);
  assert(global_push_constant_size % 16 == 0 &&
         "Global push constants size changed");
  auto mangled_struct_ty = GetTypeMangledPodArgsStruct(M);
  if (mangled_struct_ty) {
    RedeclareGlobalPushConstants(M, mangled_struct_ty);
  }

  for (Function *F : WorkList) {
    Changed = true;

    auto pod_arg_impl = clspv::GetPodArgsImpl(*F);
    auto pod_arg_kind = clspv::GetArgKindForPodArgs(*F);
    // An ArgMapping describes how a kernel argument is remapped.
    struct ArgMapping {
      std::string name;
      // 0-based argument index in the old kernel function.
      unsigned old_index;
      // 0-based argument index in the new kernel function.
      int new_index;
      // Offset of the argument value within the new kernel argument.
      // This is always zero for non-POD arguments. For a POD argument,
      // this is the byte offset within the POD arguments struct.
      unsigned offset;
      // Size of the argument.
      unsigned arg_size;
      // Argument type.
      clspv::ArgKind arg_kind;
    };

    // In OpenCL, kernel arguments are either pointers or POD. A composite with
    // an element or member that is a pointer is not allowed. So we'll use POD
    // as a shorthand for non-pointer.

    SmallVector<Type *, 8> PtrArgTys;
    SmallVector<Type *, 8> PodArgTys;
    SmallVector<ArgMapping, 8> RemapInfo;
    DenseMap<Argument *, unsigned> PodIndexMap;
    unsigned arg_index = 0;
    int new_index = 0;
    unsigned pod_index = 0;
    for (Argument &Arg : F->args()) {
      Type *ArgTy = Arg.getType();
      if (isa<PointerType>(ArgTy)) {
        PtrArgTys.push_back(ArgTy);
        const auto kind = clspv::GetArgKind(Arg);
        RemapInfo.push_back(
            {std::string(Arg.getName()), arg_index, new_index++, 0u, 0u, kind});
      } else {
        PodIndexMap[&Arg] = pod_index++;
        PodArgTys.push_back(ArgTy);
      }
      arg_index++;
    }

    // Put the pointer arguments first, and then POD arguments struct last.
    // Use StructType::get so we reuse types where possible.
    auto PodArgsStructTy = StructType::get(Context, PodArgTys);
    SmallVector<Type *, 8> NewFuncParamTys(PtrArgTys);

    if (pod_arg_impl == clspv::PodArgImpl::kUBO &&
        !clspv::Option::Std430UniformBufferLayout()) {
      SmallVector<Type *, 16> PaddedPodArgTys;
      const DataLayout DL(&M);
      const auto StructLayout = DL.getStructLayout(PodArgsStructTy);
      unsigned pod_index = 0;
      for (auto &Arg : F->args()) {
        auto arg_type = Arg.getType();
        if (arg_type->isPointerTy())
          continue;

        // The frontend has validated individual POD arguments. When the
        // unified struct is constructed, pad struct and array elements as
        // necessary to achieve a 16-byte alignment.
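        // For example (illustrative): a struct member at byte offset 20 is
        // aligned up to 32, so 12 bytes of padding precede it, emitted as
        // num_ints == 3 i32s and num_chars == 0 i8s.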
        if (arg_type->isStructTy() || arg_type->isArrayTy()) {
          auto offset = StructLayout->getElementOffset(pod_index);
          auto aligned = alignTo(offset, 16);
          if (offset < aligned) {
            auto int_ty = IntegerType::get(Context, 32);
            auto char_ty = IntegerType::get(Context, 8);
            size_t num_ints = (aligned - offset) / 4;
            size_t num_chars = (aligned - offset) - (num_ints * 4);
            assert((num_chars == 0 || clspv::Option::Int8Support()) &&
                   "Char in UBO struct without char support");
            // Fix the index for the offset of the argument.
            // Add char padding first.
            PodIndexMap[&Arg] += num_ints + num_chars;
            for (size_t i = 0; i < num_chars; ++i) {
              PaddedPodArgTys.push_back(char_ty);
            }
            for (size_t i = 0; i < num_ints; ++i) {
              PaddedPodArgTys.push_back(int_ty);
            }
          }
        }
        ++pod_index;
        PaddedPodArgTys.push_back(arg_type);
      }
      PodArgsStructTy = StructType::get(Context, PaddedPodArgTys);
    }

    if (pod_arg_impl != clspv::PodArgImpl::kGlobalPushConstant) {
      NewFuncParamTys.push_back(PodArgsStructTy);
    }

    // We've recorded the remapping for pointer arguments. Now record the
    // remapping for POD arguments.
    {
      const DataLayout DL(&M);
      const auto StructLayout = DL.getStructLayout(PodArgsStructTy);
      arg_index = 0;
      for (Argument &Arg : F->args()) {
        Type *ArgTy = Arg.getType();
        if (!isa<PointerType>(ArgTy)) {
          unsigned arg_size = DL.getTypeStoreSize(ArgTy);
          unsigned offset = StructLayout->getElementOffset(PodIndexMap[&Arg]);
          int remapped_index = new_index;
          if (pod_arg_impl == clspv::PodArgImpl::kGlobalPushConstant) {
            offset += global_push_constant_size;
            remapped_index = -1;
          }
          RemapInfo.push_back({std::string(Arg.getName()), arg_index,
                               remapped_index, offset, arg_size, pod_arg_kind});
        }
        arg_index++;
      }
    }

    FunctionType *NewFuncTy =
        FunctionType::get(F->getReturnType(), NewFuncParamTys, false);

    // Create the new function and set key properties.
    auto NewFunc = Function::Create(NewFuncTy, F->getLinkage());
    // The new function adopts the real name so that linkage to the outside
    // world remains the same.
    NewFunc->setName(F->getName());
    F->setName(NewFunc->getName().str() + ".inner");

    NewFunc->setCallingConv(F->getCallingConv());
    F->setCallingConv(CallingConv::SPIR_FUNC);

    // Transfer attributes that don't apply to the POD arguments
    // to the new function.
    auto Attributes = F->getAttributes();
    SmallVector<std::pair<unsigned, AttributeSet>, 8> AttrBuildInfo;

    // Return attributes have to come first.
    const auto retAttrs = Attributes.getRetAttrs();
    if (retAttrs.hasAttributes()) {
      auto idx = AttributeList::ReturnIndex;
      AttrBuildInfo.push_back(std::make_pair(idx, retAttrs));
    }

    // Then attributes for non-POD parameters.
    for (auto &rinfo : RemapInfo) {
      bool argIsPod = rinfo.arg_kind == clspv::ArgKind::Pod ||
                      rinfo.arg_kind == clspv::ArgKind::PodUBO ||
                      rinfo.arg_kind == clspv::ArgKind::PodPushConstant;
      if (!argIsPod && Attributes.hasParamAttrs(rinfo.old_index)) {
        auto idx = rinfo.new_index + AttributeList::FirstArgIndex;
        auto attrs = Attributes.getParamAttrs(rinfo.old_index);
        AttrBuildInfo.push_back(std::make_pair(idx, attrs));
      }
    }

    // And finally function attributes.
    const auto fnAttrs = Attributes.getFnAttrs();
    if (fnAttrs.hasAttributes()) {
      auto idx = AttributeList::FunctionIndex;
      AttrBuildInfo.push_back(std::make_pair(idx, fnAttrs));
    }
    auto newAttributes = AttributeList::get(M.getContext(), AttrBuildInfo);
    NewFunc->setAttributes(newAttributes);

    // Move OpenCL kernel named metadata.
    // TODO(dneto): Metadata named kernel_arg_* should be rewritten
    // to reflect the change in the argument shape.
    auto pod_md_name = clspv::PodArgsImplMetadataName();
    std::vector<const char *> Metadatas{
        "reqd_work_group_size", "kernel_arg_addr_space",
        "kernel_arg_access_qual", "kernel_arg_type",
        "kernel_arg_base_type", "kernel_arg_type_qual",
        pod_md_name.c_str()};
    for (auto name : Metadatas) {
      NewFunc->setMetadata(name, F->getMetadata(name));
      F->setMetadata(name, nullptr);
    }

    IRBuilder<> Builder(BasicBlock::Create(Context, "entry", NewFunc));

    // Set kernel argument mapping metadata.
    {
      // Attach a metadata node named "kernel_arg_map" to the new kernel
      // function. It is a tuple of nodes, each of which is a tuple for
      // each argument, with members:
      // - Argument name
      // - Ordinal index in the original kernel function
      // - Ordinal index in the new kernel function
      // - Byte offset within the argument. This is always 0 for pointer
      //   arguments. For POD arguments this is the offset within the POD
      //   argument struct.
      // - Argument type
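      //
      // For example, a POD argument "n" remapped into the push constant
      // struct at byte offset 16 might be described by a tuple like this
      // (values illustrative):
      //   !{!"n", i32 1, i32 -1, i32 16, i32 4, !"pod_pushconstant"}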
      LLVMContext &Context = M.getContext();
      SmallVector<Metadata *, 8> mappings;
      for (auto &arg_mapping : RemapInfo) {
        auto *name_md = MDString::get(Context, arg_mapping.name);
        auto *old_index_md =
            ConstantAsMetadata::get(Builder.getInt32(arg_mapping.old_index));
        auto *new_index_md =
            ConstantAsMetadata::get(Builder.getInt32(arg_mapping.new_index));
        auto *offset_md =
            ConstantAsMetadata::get(Builder.getInt32(arg_mapping.offset));
        auto *arg_size_md =
            ConstantAsMetadata::get(Builder.getInt32(arg_mapping.arg_size));
        auto argKindName = GetArgKindName(arg_mapping.arg_kind);
        auto *argtype_md = MDString::get(Context, argKindName);
        auto *arg_md =
            MDNode::get(Context, {name_md, old_index_md, new_index_md,
                                  offset_md, arg_size_md, argtype_md});
        mappings.push_back(arg_md);
      }

      NewFunc->setMetadata(clspv::KernelArgMapMetadataName(),
                           MDNode::get(Context, mappings));
    }

    // Insert the function after the original, to preserve ordering
    // in the module as much as possible.
    auto &FunctionList = M.getFunctionList();
    for (auto Iter = FunctionList.begin(), IterEnd = FunctionList.end();
         Iter != IterEnd; ++Iter) {
      if (&*Iter == F) {
        FunctionList.insertAfter(Iter, NewFunc);
        break;
      }
    }

    // The body of the wrapper is essentially a call to the original function,
    // but we have to unwrap the non-pointer arguments from the struct.
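    //
    // Schematically (illustrative IR), for podargs = { i32 n, float f }:
    //   %n = extractvalue { i32, float } %podargs, 0
    //   %f = extractvalue { i32, float } %podargs, 1
    //   call spir_func void @foo.inner(..., i32 %n, float %f)
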
    // Map the wrapper's arguments to the callee's arguments.
    SmallVector<Argument *, 8> CallerArgs;
    for (Argument &Arg : NewFunc->args()) {
      CallerArgs.push_back(&Arg);
    }
    Value *PodArg = nullptr;
    if (pod_arg_impl != clspv::PodArgImpl::kGlobalPushConstant) {
      Argument *pod_arg = CallerArgs.back();
      pod_arg->setName("podargs");
      PodArg = pod_arg;
    }

    SmallVector<Value *, 8> CalleeArgs;
    unsigned podCount = 0;
    unsigned ptrIndex = 0;
    for (Argument &Arg : F->args()) {
      if (isa<PointerType>(Arg.getType())) {
        CalleeArgs.push_back(CallerArgs[ptrIndex++]);
      } else {
        podCount++;
        unsigned podIndex = PodIndexMap[&Arg];
        if (pod_arg_impl == clspv::PodArgImpl::kGlobalPushConstant) {
          auto reconstructed =
              ConvertToType(M, PodArgsStructTy, podIndex, Builder);
          CalleeArgs.push_back(reconstructed);
        } else {
          CalleeArgs.push_back(Builder.CreateExtractValue(PodArg, {podIndex}));
        }
      }
      CalleeArgs.back()->setName(Arg.getName());
    }
    assert(ptrIndex + podCount == F->arg_size());
    assert(ptrIndex == PtrArgTys.size());
    assert(podCount != 0);
    assert(podCount == PodArgTys.size());

    auto Call = Builder.CreateCall(F, CalleeArgs);
    Call->setCallingConv(F->getCallingConv());
    CallList.push_back(Call);

    Builder.CreateRetVoid();
  }

  // Inline the inner function. It's cleaner to do this.
  for (CallInst *C : CallList) {
    InlineFunctionInfo info;
    Changed |= InlineFunction(*C, info).isSuccess();
  }

  return Changed;
}

StructType *
ClusterPodKernelArgumentsPass::GetTypeMangledPodArgsStruct(Module &M) {
  // If we are using global type-mangled push constants for any kernel, we need
  // to figure out what the shared representation will be. Calculate the max
  // number of integers needed to satisfy all kernels.
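  //
  // For example (illustrative): one kernel whose POD args pack into 8 bytes
  // and another whose args pack into 12 give max_pod_args_size == 12, so the
  // shared representation is { i32, i32, i32 }.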
  uint64_t max_pod_args_size = 0;
  const auto &DL = M.getDataLayout();
  for (auto &F : M) {
    if (F.isDeclaration() || F.getCallingConv() != CallingConv::SPIR_KERNEL)
      continue;

    auto pod_arg_impl = clspv::GetPodArgsImpl(F);
    if (pod_arg_impl != clspv::PodArgImpl::kGlobalPushConstant)
      continue;

    SmallVector<Type *, 8> PodArgTys;
    for (auto &Arg : F.args()) {
      if (!Arg.getType()->isPointerTy()) {
        PodArgTys.push_back(Arg.getType());
      }
    }

    // TODO: The type-mangling code will need updating if we want to support
    // packed structs.
    auto struct_ty = StructType::get(M.getContext(), PodArgTys);
    uint64_t size = alignTo(DL.getTypeStoreSize(struct_ty), kIntBytes);
    if (size > max_pod_args_size)
      max_pod_args_size = size;
  }

  if (max_pod_args_size > 0) {
    auto int_ty = IntegerType::get(M.getContext(), 32);
    std::vector<Type *> global_pod_arg_tys(max_pod_args_size / kIntBytes,
                                           int_ty);
    return StructType::create(M.getContext(), global_pod_arg_tys);
  }

  return nullptr;
}

void ClusterPodKernelArgumentsPass::RedeclareGlobalPushConstants(
    Module &M, StructType *mangled_struct_ty) {
  auto old_GV = M.getGlobalVariable(clspv::PushConstantsVariableName());

  std::vector<Type *> push_constant_tys;
  if (old_GV) {
    auto block_ty =
        cast<StructType>(old_GV->getType()->getPointerElementType());
    for (auto ele : block_ty->elements())
      push_constant_tys.push_back(ele);
  }
  push_constant_tys.push_back(mangled_struct_ty);

  auto push_constant_ty = StructType::create(M.getContext(), push_constant_tys);
  auto new_GV = new GlobalVariable(
      M, push_constant_ty, false, GlobalValue::ExternalLinkage, nullptr, "",
      nullptr, GlobalValue::ThreadLocalMode::NotThreadLocal,
      clspv::AddressSpace::PushConstant);
  new_GV->setInitializer(Constant::getNullValue(push_constant_ty));
  std::vector<Metadata *> md_args;
  if (old_GV) {
    // Replace the old push constant variable metadata and uses.
    new_GV->takeName(old_GV);
    auto md = old_GV->getMetadata(clspv::PushConstantsMetadataName());
    for (auto &op : md->operands()) {
      md_args.push_back(op.get());
    }
    std::vector<User *> users;
    for (auto user : old_GV->users())
      users.push_back(user);
    for (auto user : users) {
      if (auto gep = dyn_cast<GetElementPtrInst>(user)) {
        // Most uses are likely constant geps, but handle instructions first
        // since we can only really access gep operators for the constant side.
        SmallVector<Value *, 4> indices;
        for (auto iter = gep->idx_begin(); iter != gep->idx_end(); ++iter) {
          indices.push_back(*iter);
        }
        auto new_gep = GetElementPtrInst::Create(push_constant_ty, new_GV,
                                                 indices, "", gep);
        new_gep->setIsInBounds(gep->isInBounds());
        gep->replaceAllUsesWith(new_gep);
        gep->eraseFromParent();
      } else if (auto gep_operator = dyn_cast<GEPOperator>(user)) {
        SmallVector<Constant *, 4> indices;
        for (auto iter = gep_operator->idx_begin();
             iter != gep_operator->idx_end(); ++iter) {
          indices.push_back(cast<Constant>(*iter));
        }
        auto new_gep = ConstantExpr::getGetElementPtr(
            push_constant_ty, new_GV, indices, gep_operator->isInBounds());
        user->replaceAllUsesWith(new_gep);
      } else {
        assert(false && "unexpected global use");
      }
    }
    old_GV->removeDeadConstantUsers();
    old_GV->eraseFromParent();
  } else {
    new_GV->setName(clspv::PushConstantsVariableName());
  }
  // New metadata operand for the kernel arguments.
  auto cst =
      ConstantInt::get(IntegerType::get(M.getContext(), 32),
                       static_cast<int>(clspv::PushConstant::KernelArgument));
  md_args.push_back(ConstantAsMetadata::get(cst));
  new_GV->setMetadata(clspv::PushConstantsMetadataName(),
                      MDNode::get(M.getContext(), md_args));
}

Value *ClusterPodKernelArgumentsPass::ConvertToType(Module &M,
                                                    StructType *pod_struct,
                                                    unsigned index,
                                                    IRBuilder<> &builder) {
  auto int32_ty = IntegerType::get(M.getContext(), 32);
  const auto &DL = M.getDataLayout();
  const auto struct_layout = DL.getStructLayout(pod_struct);
  auto ele_ty = pod_struct->getElementType(index);
  const auto ele_size = DL.getTypeStoreSize(ele_ty).getKnownMinSize();
  auto ele_offset = struct_layout->getElementOffset(index);
  const auto ele_start_index = ele_offset / kIntBytes; // round down
  const auto ele_end_index =
      (ele_offset + ele_size + kIntBytes - 1) / kIntBytes; // round up

  // Load the right number of ints. We'll load at least one, but may load
  // ele_size / 4 + 1 integers depending on the offset.
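  //
  // For example (illustrative): a 2-byte half at byte offset 6 has
  // ele_start_index == 1 and ele_end_index == 2, so only the second i32 is
  // loaded; an 8-byte value at offset 8 loads i32s 2 and 3.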
  std::vector<Value *> int_elements;
  uint32_t i = ele_start_index;
  do {
    auto gep = clspv::GetPushConstantPointer(
        builder.GetInsertBlock(), clspv::PushConstant::KernelArgument,
        {builder.getInt32(i)});
    auto ld = builder.CreateLoad(int32_ty, gep);
    int_elements.push_back(ld);
    i++;
  } while (i < ele_end_index);

  return BuildFromElements(M, builder, ele_ty, ele_offset % kIntBytes, 0,
                           int_elements);
}

Value *ClusterPodKernelArgumentsPass::BuildFromElements(
    Module &M, IRBuilder<> &builder, Type *dst_type, uint64_t base_offset,
    uint64_t base_index, const std::vector<Value *> &elements) {
  auto int32_ty = IntegerType::get(M.getContext(), 32);
  const auto &DL = M.getDataLayout();
  const auto dst_size = DL.getTypeStoreSize(dst_type).getKnownMinSize();
  auto dst_array_ty = dyn_cast<ArrayType>(dst_type);
  auto dst_vec_ty = dyn_cast<VectorType>(dst_type);

  Value *dst = nullptr;
  if (auto dst_struct_ty = dyn_cast<StructType>(dst_type)) {
    // Create an insertvalue chain for each converted element.
    auto struct_layout = DL.getStructLayout(dst_struct_ty);
    for (uint32_t i = 0; i < dst_struct_ty->getNumElements(); ++i) {
      auto ele_ty = dst_struct_ty->getTypeAtIndex(i);
      const auto ele_offset = struct_layout->getElementOffset(i);
      const auto index = base_index + (ele_offset / kIntBytes);
      const auto offset = (base_offset + ele_offset) % kIntBytes;

      auto tmp = BuildFromElements(M, builder, ele_ty, offset, index, elements);
      dst = builder.CreateInsertValue(dst ? dst : UndefValue::get(dst_type),
                                      tmp, {i});
    }
  } else if (dst_array_ty || dst_vec_ty) {
    if (dst_vec_ty && dst_vec_ty->getPrimitiveSizeInBits() ==
                          int32_ty->getPrimitiveSizeInBits()) {
      // Easy case is just a bitcast.
      dst = builder.CreateBitCast(elements[base_index], dst_type);
    } else if (dst_vec_ty &&
               dst_vec_ty->getElementType()->getPrimitiveSizeInBits() <
                   int32_ty->getPrimitiveSizeInBits()) {
      // Bitcast integers to a vector of the primitive type and then shuffle
      // elements into the final vector.
      //
      // We need at most two integers to handle any case here.
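      //
      // For example (illustrative): a <4 x half> spans two i32s; each i32 is
      // bitcast to <2 x half> and a shufflevector with mask <0, 1, 2, 3>
      // concatenates the lanes into the final <4 x half>.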
      auto ele_ty = dst_vec_ty->getElementType();
      uint32_t num_elements = dst_vec_ty->getElementCount().getKnownMinValue();
      assert(num_elements <= 4 && "Unhandled large vectors");
      uint32_t ratio = (int32_ty->getPrimitiveSizeInBits() /
                        ele_ty->getPrimitiveSizeInBits());
      auto scaled_vec_ty = FixedVectorType::get(ele_ty, ratio);
      Value *casts[2] = {UndefValue::get(scaled_vec_ty),
                         UndefValue::get(scaled_vec_ty)};
      uint32_t num_ints = (num_elements + ratio - 1) / ratio; // round up
      num_ints = std::max(num_ints, 1u);
      for (uint32_t i = 0; i < num_ints; ++i) {
        casts[i] =
            builder.CreateBitCast(elements[base_index + i], scaled_vec_ty);
      }
      SmallVector<int, 4> indices(num_elements);
      uint32_t i = 0;
      std::generate_n(indices.data(), num_elements, [&i]() { return i++; });
      dst = builder.CreateShuffleVector(casts[0], casts[1], indices);
    } else {
      // General case, break into elements and construct the composite type.
      auto ele_ty = dst_vec_ty ? dst_vec_ty->getElementType()
                               : dst_array_ty->getElementType();
      assert((DL.getTypeStoreSize(ele_ty).getKnownMinSize() < kIntBytes ||
              base_offset == 0) &&
             "Unexpected packed data format");
      uint64_t ele_size = DL.getTypeStoreSize(ele_ty);
      uint32_t num_elements =
          dst_vec_ty ? dst_vec_ty->getElementCount().getKnownMinValue()
                     : dst_array_ty->getNumElements();

      // Arrays of shorts/halves could be offset from the start of an int.
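      // For example (illustrative): with base_offset == 2, element 0 of a
      // half array has ele_offset == 2 and ele_index == base_index, i.e. it
      // comes from the upper two bytes of the first loaded i32.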
      uint64_t bytes_consumed = 0;
      for (uint32_t i = 0; i < num_elements; ++i) {
        uint64_t ele_offset = (base_offset + bytes_consumed) % kIntBytes;
        uint64_t ele_index =
            base_index + (base_offset + bytes_consumed) / kIntBytes;
        // Convert the element.
        auto tmp = BuildFromElements(M, builder, ele_ty, ele_offset, ele_index,
                                     elements);
        if (dst_vec_ty) {
          dst = builder.CreateInsertElement(
              dst ? dst : UndefValue::get(dst_type), tmp, i);
        } else {
          dst = builder.CreateInsertValue(dst ? dst : UndefValue::get(dst_type),
                                          tmp, {i});
        }

        // Track consumed bytes.
        bytes_consumed += ele_size;
      }
    }
  } else {
    // Base case is scalar conversion.
    if (dst_size < kIntBytes) {
      dst = elements[base_index];
      if (dst_type->isIntegerTy() && base_offset == 0) {
        // Can generate a single truncate instruction in this case.
        dst = builder.CreateTrunc(
            dst, IntegerType::get(M.getContext(), dst_size * 8));
      } else {
        // Bitcast to a vector of |dst_type| and extract the right element.
        // This avoids introducing i16 when converting to half.
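        // For example (illustrative): a half at base_offset == 2 is bitcast
        // from i32 to <2 x half>, and lane 2 / 2 == 1 is extracted.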
        uint32_t ratio = (int32_ty->getPrimitiveSizeInBits() /
                          dst_type->getPrimitiveSizeInBits());
        auto vec_ty = FixedVectorType::get(dst_type, ratio);
        dst = builder.CreateBitCast(dst, vec_ty);
        dst = builder.CreateExtractElement(dst, base_offset / dst_size);
      }
    } else if (dst_size == kIntBytes) {
      assert(base_offset == 0 && "Unexpected packed data format");
      // Create a bit cast if necessary.
      dst = elements[base_index];
      if (dst_type != int32_ty)
        dst = builder.CreateBitCast(dst, dst_type);
    } else {
      assert(base_offset == 0 && "Unexpected packed data format");
      assert(dst_size == kIntBytes * 2 && "Expected 64-bit scalar");
      // Assemble the 64-bit scalar from two i32 loads: zero-extend both,
      // shift the high word left by 32, and OR them together.
      auto dst_int = IntegerType::get(M.getContext(), dst_size * 8);
      auto zext0 = builder.CreateZExt(elements[base_index], dst_int);
      auto zext1 = builder.CreateZExt(elements[base_index + 1], dst_int);
      auto shl = builder.CreateShl(zext1, 32);
      dst = builder.CreateOr({zext0, shl});
      if (dst_type != dst->getType())
        dst = builder.CreateBitCast(dst, dst_type);
    }
  }

  return dst;
}