blob: 8be3ef922caf06f6cece000765b4a82a60281ede [file] [log] [blame]
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16
#include "inline_pass.h"

#include <unordered_set>
18
// Operand positions of the SPIR-V instruction fields this pass reads.
// All but kSpvReturnValueId are passed to GetSingleWordOperand(), whose
// indexing counts every operand (result type / result id included).
// kSpvReturnValueId is passed to GetInOperand(), which indexes in-operands only.

// OpEntryPoint: <execution model> <function id> ...
static constexpr int kSpvEntryPointFunctionId = 1;
// OpFunctionCall: <result type> <result id> <function id> <arg0> <arg1> ...
static constexpr int kSpvFunctionCallFunctionId = 2;
static constexpr int kSpvFunctionCallArgumentId = 3;
// OpReturnValue: <value> (in-operand index)
static constexpr int kSpvReturnValueId = 0;
// OpTypePointer: <result id> <storage class> <pointee type id>
static constexpr int kSpvTypePointerStorageClass = 1;
static constexpr int kSpvTypePointerTypeId = 2;
27
28namespace spvtools {
29namespace opt {
30
31uint32_t InlinePass::FindPointerToType(uint32_t type_id,
32 SpvStorageClass storage_class) {
33 ir::Module::inst_iterator type_itr = module_->types_values_begin();
34 for (; type_itr != module_->types_values_end(); ++type_itr) {
35 const ir::Instruction* type_inst = &*type_itr;
36 if (type_inst->opcode() == SpvOpTypePointer &&
37 type_inst->GetSingleWordOperand(kSpvTypePointerTypeId) == type_id &&
38 type_inst->GetSingleWordOperand(kSpvTypePointerStorageClass) ==
39 storage_class)
40 return type_inst->result_id();
41 }
42 return 0;
43}
44
45uint32_t InlinePass::AddPointerToType(uint32_t type_id,
46 SpvStorageClass storage_class) {
47 uint32_t resultId = TakeNextId();
48 std::unique_ptr<ir::Instruction> type_inst(new ir::Instruction(
49 SpvOpTypePointer, 0, resultId,
50 {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
51 {uint32_t(storage_class)}},
52 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}}));
53 module_->AddType(std::move(type_inst));
54 return resultId;
55}
56
57void InlinePass::AddBranch(uint32_t label_id,
58 std::unique_ptr<ir::BasicBlock>* block_ptr) {
59 std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
60 SpvOpBranch, 0, 0,
61 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}}));
62 (*block_ptr)->AddInstruction(std::move(newBranch));
63}
64
65void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id,
66 std::unique_ptr<ir::BasicBlock>* block_ptr) {
67 std::unique_ptr<ir::Instruction> newStore(new ir::Instruction(
68 SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}},
69 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}}));
70 (*block_ptr)->AddInstruction(std::move(newStore));
71}
72
73void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id,
74 std::unique_ptr<ir::BasicBlock>* block_ptr) {
75 std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction(
76 SpvOpLoad, type_id, resultId,
77 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}}));
78 (*block_ptr)->AddInstruction(std::move(newLoad));
79}
80
81std::unique_ptr<ir::Instruction> InlinePass::NewLabel(uint32_t label_id) {
82 std::unique_ptr<ir::Instruction> newLabel(
83 new ir::Instruction(SpvOpLabel, 0, label_id, {}));
84 return newLabel;
85}
86
87void InlinePass::MapParams(
88 ir::Function* calleeFn,
89 ir::UptrVectorIterator<ir::Instruction> call_inst_itr,
90 std::unordered_map<uint32_t, uint32_t>* callee2caller) {
91 int param_idx = 0;
92 calleeFn->ForEachParam(
93 [&call_inst_itr, &param_idx, &callee2caller](const ir::Instruction* cpi) {
94 const uint32_t pid = cpi->result_id();
95 (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand(
96 kSpvFunctionCallArgumentId + param_idx);
97 param_idx++;
98 });
99}
100
101void InlinePass::CloneAndMapLocals(
102 ir::Function* calleeFn,
103 std::vector<std::unique_ptr<ir::Instruction>>* new_vars,
104 std::unordered_map<uint32_t, uint32_t>* callee2caller) {
105 auto callee_block_itr = calleeFn->begin();
106 auto callee_var_itr = callee_block_itr->begin();
107 while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) {
108 std::unique_ptr<ir::Instruction> var_inst(
109 new ir::Instruction(*callee_var_itr));
110 uint32_t newId = TakeNextId();
111 var_inst->SetResultId(newId);
112 (*callee2caller)[callee_var_itr->result_id()] = newId;
113 new_vars->push_back(std::move(var_inst));
114 callee_var_itr++;
115 }
116}
117
118uint32_t InlinePass::CreateReturnVar(
119 ir::Function* calleeFn,
120 std::vector<std::unique_ptr<ir::Instruction>>* new_vars) {
121 uint32_t returnVarId = 0;
122 const uint32_t calleeTypeId = calleeFn->type_id();
123 const ir::Instruction* calleeType =
124 def_use_mgr_->id_to_defs().find(calleeTypeId)->second;
125 if (calleeType->opcode() != SpvOpTypeVoid) {
126 // Find or create ptr to callee return type.
127 uint32_t returnVarTypeId =
128 FindPointerToType(calleeTypeId, SpvStorageClassFunction);
129 if (returnVarTypeId == 0)
130 returnVarTypeId = AddPointerToType(calleeTypeId, SpvStorageClassFunction);
131 // Add return var to new function scope variables.
132 returnVarId = TakeNextId();
133 std::unique_ptr<ir::Instruction> var_inst(new ir::Instruction(
134 SpvOpVariable, returnVarTypeId, returnVarId,
135 {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
136 {SpvStorageClassFunction}}}));
137 new_vars->push_back(std::move(var_inst));
138 }
139 return returnVarId;
140}
141
142bool InlinePass::IsSameBlockOp(const ir::Instruction* inst) const {
143 return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage;
144}
145
146void InlinePass::CloneSameBlockOps(
147 std::unique_ptr<ir::Instruction>* inst,
148 std::unordered_map<uint32_t, uint32_t>* postCallSB,
149 std::unordered_map<uint32_t, ir::Instruction*>* preCallSB,
150 std::unique_ptr<ir::BasicBlock>* block_ptr) {
151 (*inst)
152 ->ForEachInId([&postCallSB, &preCallSB, &block_ptr, this](uint32_t* iid) {
153 const auto mapItr = (*postCallSB).find(*iid);
154 if (mapItr == (*postCallSB).end()) {
155 const auto mapItr2 = (*preCallSB).find(*iid);
156 if (mapItr2 != (*preCallSB).end()) {
157 // Clone pre-call same-block ops, map result id.
158 const ir::Instruction* inInst = mapItr2->second;
159 std::unique_ptr<ir::Instruction> sb_inst(
160 new ir::Instruction(*inInst));
161 CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr);
162 const uint32_t rid = sb_inst->result_id();
163 const uint32_t nid = this->TakeNextId();
164 sb_inst->SetResultId(nid);
165 (*postCallSB)[rid] = nid;
166 *iid = nid;
167 (*block_ptr)->AddInstruction(std::move(sb_inst));
168 }
169 } else {
170 // Reset same-block op operand.
171 *iid = mapItr->second;
172 }
173 });
174}
175
176void InlinePass::GenInlineCode(
177 std::vector<std::unique_ptr<ir::BasicBlock>>* new_blocks,
178 std::vector<std::unique_ptr<ir::Instruction>>* new_vars,
179 ir::UptrVectorIterator<ir::Instruction> call_inst_itr,
180 ir::UptrVectorIterator<ir::BasicBlock> call_block_itr) {
181 // Map from all ids in the callee to their equivalent id in the caller
182 // as callee instructions are copied into caller.
183 std::unordered_map<uint32_t, uint32_t> callee2caller;
184 // Pre-call same-block insts
185 std::unordered_map<uint32_t, ir::Instruction*> preCallSB;
186 // Post-call same-block op ids
187 std::unordered_map<uint32_t, uint32_t> postCallSB;
188
189 ir::Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand(
190 kSpvFunctionCallFunctionId)];
191
192 // Map parameters to actual arguments.
193 MapParams(calleeFn, call_inst_itr, &callee2caller);
194
195 // Define caller local variables for all callee variables and create map to
196 // them.
197 CloneAndMapLocals(calleeFn, new_vars, &callee2caller);
198
199 // Create return var if needed.
200 uint32_t returnVarId = CreateReturnVar(calleeFn, new_vars);
201
202 // Clone and map callee code. Copy caller block code to beginning of
203 // first block and end of last block.
204 bool prevInstWasReturn = false;
205 uint32_t returnLabelId = 0;
206 bool multiBlocks = false;
207 const uint32_t calleeTypeId = calleeFn->type_id();
208 std::unique_ptr<ir::BasicBlock> new_blk_ptr;
209 calleeFn->ForEachInst([&new_blocks, &callee2caller, &call_block_itr,
210 &call_inst_itr, &new_blk_ptr, &prevInstWasReturn,
211 &returnLabelId, &returnVarId, &calleeTypeId,
212 &multiBlocks, &postCallSB, &preCallSB, this](
213 const ir::Instruction* cpi) {
214 switch (cpi->opcode()) {
215 case SpvOpFunction:
216 case SpvOpFunctionParameter:
217 case SpvOpVariable:
218 // Already processed
219 break;
220 case SpvOpLabel: {
221 // If previous instruction was early return, insert branch
222 // instruction to return block.
223 if (prevInstWasReturn) {
224 if (returnLabelId == 0) returnLabelId = this->TakeNextId();
225 AddBranch(returnLabelId, &new_blk_ptr);
226 prevInstWasReturn = false;
227 }
228 // Finish current block (if it exists) and get label for next block.
229 uint32_t labelId;
230 bool firstBlock = false;
231 if (new_blk_ptr != nullptr) {
232 new_blocks->push_back(std::move(new_blk_ptr));
233 // If result id is already mapped, use it, otherwise get a new
234 // one.
235 const uint32_t rid = cpi->result_id();
236 const auto mapItr = callee2caller.find(rid);
237 labelId = (mapItr != callee2caller.end()) ? mapItr->second
238 : this->TakeNextId();
239 } else {
240 // First block needs to use label of original block
241 // but map callee label in case of phi reference.
242 labelId = call_block_itr->label_id();
243 callee2caller[cpi->result_id()] = labelId;
244 firstBlock = true;
245 }
246 // Create first/next block.
247 new_blk_ptr.reset(new ir::BasicBlock(NewLabel(labelId)));
248 if (firstBlock) {
249 // Copy contents of original caller block up to call instruction.
250 for (auto cii = call_block_itr->begin(); cii != call_inst_itr;
251 cii++) {
252 std::unique_ptr<ir::Instruction> cp_inst(new ir::Instruction(*cii));
253 // Remember same-block ops for possible regeneration.
254 if (IsSameBlockOp(&*cp_inst)) {
255 auto* sb_inst_ptr = cp_inst.get();
256 preCallSB[cp_inst->result_id()] = sb_inst_ptr;
257 }
258 new_blk_ptr->AddInstruction(std::move(cp_inst));
259 }
260 } else {
261 multiBlocks = true;
262 }
263 } break;
264 case SpvOpReturnValue: {
265 // Store return value to return variable.
266 assert(returnVarId != 0);
267 uint32_t valId = cpi->GetInOperand(kSpvReturnValueId).words[0];
268 const auto mapItr = callee2caller.find(valId);
269 if (mapItr != callee2caller.end()) {
270 valId = mapItr->second;
271 }
272 AddStore(returnVarId, valId, &new_blk_ptr);
273
274 // Remember we saw a return; if followed by a label, will need to
275 // insert branch.
276 prevInstWasReturn = true;
277 } break;
278 case SpvOpReturn: {
279 // Remember we saw a return; if followed by a label, will need to
280 // insert branch.
281 prevInstWasReturn = true;
282 } break;
283 case SpvOpFunctionEnd: {
284 // If there was an early return, create return label/block.
285 // If previous instruction was return, insert branch instruction
286 // to return block.
287 if (returnLabelId != 0) {
288 if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr);
289 new_blocks->push_back(std::move(new_blk_ptr));
290 new_blk_ptr.reset(new ir::BasicBlock(NewLabel(returnLabelId)));
291 multiBlocks = true;
292 }
293 // Load return value into result id of call, if it exists.
294 if (returnVarId != 0) {
295 const uint32_t resId = call_inst_itr->result_id();
296 assert(resId != 0);
297 AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr);
298 }
299 // Copy remaining instructions from caller block.
300 auto cii = call_inst_itr;
301 for (cii++; cii != call_block_itr->end(); cii++) {
302 std::unique_ptr<ir::Instruction> cp_inst(new ir::Instruction(*cii));
303 // If multiple blocks generated, regenerate any same-block
304 // instruction that has not been seen in this last block.
305 if (multiBlocks) {
306 CloneSameBlockOps(&cp_inst, &postCallSB, &preCallSB, &new_blk_ptr);
307 // Remember same-block ops in this block.
308 if (IsSameBlockOp(&*cp_inst)) {
309 const uint32_t rid = cp_inst->result_id();
310 postCallSB[rid] = rid;
311 }
312 }
313 new_blk_ptr->AddInstruction(std::move(cp_inst));
314 }
315 // Finalize inline code.
316 new_blocks->push_back(std::move(new_blk_ptr));
317 } break;
318 default: {
319 // Copy callee instruction and remap all input Ids.
320 std::unique_ptr<ir::Instruction> cp_inst(new ir::Instruction(*cpi));
321 cp_inst->ForEachInId([&callee2caller, &cpi, this](uint32_t* iid) {
322 const auto mapItr = callee2caller.find(*iid);
323 if (mapItr != callee2caller.end()) {
324 *iid = mapItr->second;
325 } else if (cpi->has_labels()) {
326 const ir::Instruction* inst =
327 def_use_mgr_->id_to_defs().find(*iid)->second;
328 if (inst->opcode() == SpvOpLabel) {
329 // Forward label reference. Allocate a new label id, map it,
330 // use it and check for it at each label.
331 const uint32_t nid = this->TakeNextId();
332 callee2caller[*iid] = nid;
333 *iid = nid;
334 }
335 }
336 });
337 // Map and reset result id.
338 const uint32_t rid = cp_inst->result_id();
339 if (rid != 0) {
340 const uint32_t nid = this->TakeNextId();
341 callee2caller[rid] = nid;
342 cp_inst->SetResultId(nid);
343 }
344 new_blk_ptr->AddInstruction(std::move(cp_inst));
345 } break;
346 }
347 });
348 // Update block map given replacement blocks.
349 for (auto& blk : *new_blocks) {
350 id2block_[blk->label_id()] = &*blk;
351 }
352}
353
David Netoceb1d4f2017-03-31 10:36:58 -0400354bool InlinePass::IsInlinableFunctionCall(const ir::Instruction* inst) {
355 if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false;
356 const ir::Function* calleeFn =
357 id2function_[inst->GetSingleWordOperand(kSpvFunctionCallFunctionId)];
358 // We can only inline a function if it has blocks.
359 return calleeFn->cbegin() != calleeFn->cend();
360}
361
Greg Fischer04fcc662016-11-10 10:11:50 -0700362bool InlinePass::Inline(ir::Function* func) {
363 bool modified = false;
364 // Using block iterators here because of block erasures and insertions.
365 for (auto bi = func->begin(); bi != func->end(); bi++) {
366 for (auto ii = bi->begin(); ii != bi->end();) {
David Netoceb1d4f2017-03-31 10:36:58 -0400367 if (IsInlinableFunctionCall(&*ii)) {
Greg Fischer04fcc662016-11-10 10:11:50 -0700368 // Inline call.
369 std::vector<std::unique_ptr<ir::BasicBlock>> newBlocks;
370 std::vector<std::unique_ptr<ir::Instruction>> newVars;
371 GenInlineCode(&newBlocks, &newVars, ii, bi);
372 // Update phi functions in successor blocks if call block
373 // is replaced with more than one block.
374 if (newBlocks.size() > 1) {
375 const auto firstBlk = newBlocks.begin();
376 const auto lastBlk = newBlocks.end() - 1;
377 const uint32_t firstId = (*firstBlk)->label_id();
378 const uint32_t lastId = (*lastBlk)->label_id();
David Netoceb1d4f2017-03-31 10:36:58 -0400379 (*lastBlk)->ForEachSuccessorLabel(
380 [&firstId, &lastId, this](uint32_t succ) {
Greg Fischer04fcc662016-11-10 10:11:50 -0700381 ir::BasicBlock* sbp = this->id2block_[succ];
382 sbp->ForEachPhiInst([&firstId, &lastId](ir::Instruction* phi) {
383 phi->ForEachInId([&firstId, &lastId](uint32_t* id) {
384 if (*id == firstId) *id = lastId;
385 });
386 });
387 });
388 }
389 // Replace old calling block with new block(s).
390 bi = bi.Erase();
391 bi = bi.InsertBefore(&newBlocks);
392 // Insert new function variables.
393 if (newVars.size() > 0) func->begin()->begin().InsertBefore(&newVars);
394 // Restart inlining at beginning of calling block.
395 ii = bi->begin();
396 modified = true;
397 } else {
398 ii++;
399 }
400 }
401 }
402 return modified;
403}
404
405void InlinePass::Initialize(ir::Module* module) {
406 def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
407
408 // Initialize next unused Id.
409 next_id_ = module->id_bound();
410
411 // Save module.
412 module_ = module;
413
414 // Initialize function and block maps.
415 id2function_.clear();
416 id2block_.clear();
417 for (auto& fn : *module_) {
418 id2function_[fn.result_id()] = &fn;
419 for (auto& blk : fn) {
420 id2block_[blk.label_id()] = &blk;
421 }
422 }
423};
424
425Pass::Status InlinePass::ProcessImpl() {
426 // Do exhaustive inlining on each entry point function in module
427 bool modified = false;
428 for (auto& e : module_->entry_points()) {
429 ir::Function* fn =
430 id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
431 modified = modified || Inline(fn);
432 }
433
434 FinalizeNextId(module_);
435
436 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
437}
438
439InlinePass::InlinePass()
440 : module_(nullptr), def_use_mgr_(nullptr), next_id_(0) {}
441
442Pass::Status InlinePass::Process(ir::Module* module) {
443 Initialize(module);
444 return ProcessImpl();
445}
446
447} // namespace opt
448} // namespace spvtools