blob: 78c517746d9dfffb433990b444adb0f628584643 [file] [log] [blame]
alan-bakera1be3322020-04-20 12:48:18 -04001// Copyright 2020 The Clspv Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "llvm/IR/Constants.h"
16
17#include "Constants.h"
18#include "SpecConstant.h"
19
20using namespace llvm;
21
22namespace {
23
24void InitSpecConstantMetadata(Module *module) {
25 auto next_spec_id_md =
26 module->getOrInsertNamedMetadata(clspv::NextSpecConstantMetadataName());
27 next_spec_id_md->clearOperands();
28
29 // Start at 3 to accommodate workgroup size ids.
30 const uint32_t first_spec_id = 3;
31 auto id_const = ValueAsMetadata::getConstant(ConstantInt::get(
32 IntegerType::get(module->getContext(), 32), first_spec_id));
33 auto id_md = MDTuple::get(module->getContext(), {id_const});
34 next_spec_id_md->addOperand(id_md);
35
36 auto spec_constant_list_md =
37 module->getOrInsertNamedMetadata(clspv::SpecConstantMetadataName());
38 spec_constant_list_md->clearOperands();
39}
40
41} // namespace
42
43namespace clspv {
44
45const char *GetSpecConstantName(SpecConstant kind) {
46 switch (kind) {
47 case SpecConstant::kWorkgroupSizeX:
48 return "workgroup_size_x";
49 case SpecConstant::kWorkgroupSizeY:
50 return "workgroup_size_y";
51 case SpecConstant::kWorkgroupSizeZ:
52 return "workgroup_size_z";
53 case SpecConstant::kLocalMemorySize:
54 return "local_memory_size";
55 }
56 llvm::errs() << "Unhandled case in clspv::GetSpecConstantName: " << int(kind)
57 << "\n";
58 return "";
59}
60
61SpecConstant GetSpecConstantFromName(const std::string &name) {
62 if (name == "workgroup_size_x")
63 return SpecConstant::kWorkgroupSizeX;
64 else if (name == "workgroup_size_y")
65 return SpecConstant::kWorkgroupSizeY;
66 else if (name == "workgroup_size_z")
67 return SpecConstant::kWorkgroupSizeZ;
68 else if (name == "local_memory_size")
69 return SpecConstant::kLocalMemorySize;
70
71 llvm::errs() << "Unhandled csae in clspv::GetSpecConstantFromName: " << name
72 << "\n";
73 return SpecConstant::kWorkgroupSizeX;
74}
75
76void AddWorkgroupSpecConstants(Module *module) {
77 auto spec_constant_list_md =
78 module->getNamedMetadata(SpecConstantMetadataName());
79 if (!spec_constant_list_md) {
80 InitSpecConstantMetadata(module);
81 spec_constant_list_md =
82 module->getNamedMetadata(SpecConstantMetadataName());
83 }
84
85 // Workgroup size spec constants always occupy ids 0, 1 and 2.
86 auto enum_const = ValueAsMetadata::getConstant(
87 ConstantInt::get(IntegerType::get(module->getContext(), 32),
88 static_cast<uint64_t>(SpecConstant::kWorkgroupSizeX)));
89 auto id_const = ValueAsMetadata::getConstant(
90 ConstantInt::get(IntegerType::get(module->getContext(), 32), 0));
91 auto wg_md = MDTuple::get(module->getContext(), {enum_const, id_const});
92 spec_constant_list_md->addOperand(wg_md);
93
94 enum_const = ValueAsMetadata::getConstant(
95 ConstantInt::get(IntegerType::get(module->getContext(), 32),
96 static_cast<uint64_t>(SpecConstant::kWorkgroupSizeY)));
97 id_const = ValueAsMetadata::getConstant(
98 ConstantInt::get(IntegerType::get(module->getContext(), 32), 1));
99 wg_md = MDTuple::get(module->getContext(), {enum_const, id_const});
100 spec_constant_list_md->addOperand(wg_md);
101
102 enum_const = ValueAsMetadata::getConstant(
103 ConstantInt::get(IntegerType::get(module->getContext(), 32),
104 static_cast<uint64_t>(SpecConstant::kWorkgroupSizeZ)));
105 id_const = ValueAsMetadata::getConstant(
106 ConstantInt::get(IntegerType::get(module->getContext(), 32), 2));
107 wg_md = MDTuple::get(module->getContext(), {enum_const, id_const});
108 spec_constant_list_md->addOperand(wg_md);
109}
110
111uint32_t AllocateSpecConstant(Module *module, SpecConstant kind) {
112 auto spec_constant_id_md =
113 module->getNamedMetadata(NextSpecConstantMetadataName());
114 if (!spec_constant_id_md) {
115 InitSpecConstantMetadata(module);
116 spec_constant_id_md =
117 module->getNamedMetadata(NextSpecConstantMetadataName());
118 }
119
120 auto value_md = spec_constant_id_md->getOperand(0);
121 auto value = cast<ConstantInt>(
122 dyn_cast<ValueAsMetadata>(value_md->getOperand(0))->getValue());
123 uint32_t next_id = static_cast<uint32_t>(value->getZExtValue());
124 // Update the next available id.
125 value_md->replaceOperandWith(
126 0, ValueAsMetadata::getConstant(ConstantInt::get(
127 IntegerType::get(module->getContext(), 32), next_id + 1)));
128
129 // Add the allocation to the metadata list.
130 auto spec_constant_list_md =
131 module->getNamedMetadata(SpecConstantMetadataName());
132 auto enum_const = ValueAsMetadata::getConstant(ConstantInt::get(
133 IntegerType::get(module->getContext(), 32), static_cast<uint64_t>(kind)));
134 auto id_const = ValueAsMetadata::getConstant(
135 ConstantInt::get(IntegerType::get(module->getContext(), 32), next_id));
136 auto wg_md = MDTuple::get(module->getContext(), {enum_const, id_const});
137 spec_constant_list_md->addOperand(wg_md);
138
139 return next_id;
140}
141
142std::vector<std::pair<SpecConstant, uint32_t>>
143GetSpecConstants(Module *module) {
144 std::vector<std::pair<SpecConstant, uint32_t>> spec_constants;
145 auto spec_constant_md =
146 module->getNamedMetadata(clspv::SpecConstantMetadataName());
147 if (!spec_constant_md)
148 return spec_constants;
149
150 for (auto pair : spec_constant_md->operands()) {
151 // Metadata is formatted as pairs of <SpecConstant, id>.
152 auto kind = static_cast<SpecConstant>(
153 cast<ConstantInt>(
154 cast<ValueAsMetadata>(pair->getOperand(0))->getValue())
155 ->getZExtValue());
156
157 uint32_t spec_id = static_cast<uint32_t>(
158 cast<ConstantInt>(
159 cast<ValueAsMetadata>(pair->getOperand(1))->getValue())
160 ->getZExtValue());
161 spec_constants.emplace_back(kind, spec_id);
162 }
163
164 return spec_constants;
165}
166
167} // namespace clspv