blob: 54f20b2e0cfaaaaa9eb46d11b2005b3acbfd1500 [file] [log] [blame]
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "src/transform/msl.h"
16
James Price960aa2e2021-06-19 00:33:35 +000017#include <memory>
James Price7a47fa82021-05-26 15:41:02 +000018#include <unordered_map>
Ben Clayton2101c352021-02-10 21:22:03 +000019#include <utility>
James Price7a47fa82021-05-26 15:41:02 +000020#include <vector>
Ben Clayton2101c352021-02-10 21:22:03 +000021
James Price7a47fa82021-05-26 15:41:02 +000022#include "src/ast/disable_validation_decoration.h"
23#include "src/program_builder.h"
24#include "src/sem/call.h"
25#include "src/sem/function.h"
26#include "src/sem/statement.h"
27#include "src/sem/variable.h"
James Price960aa2e2021-06-19 00:33:35 +000028#include "src/transform/array_length_from_uniform.h"
James Pricef8f31a42021-04-09 13:50:38 +000029#include "src/transform/canonicalize_entry_point_io.h"
Brandon Jonesc705b6c2021-05-10 16:15:31 +000030#include "src/transform/external_texture_transform.h"
James Price567f2e42021-06-18 09:47:23 +000031#include "src/transform/inline_pointer_lets.h"
James Pricef8f31a42021-04-09 13:50:38 +000032#include "src/transform/manager.h"
Ben Clayton31936f32021-06-16 09:50:11 +000033#include "src/transform/pad_array_elements.h"
James Price42220ba2021-06-01 12:08:20 +000034#include "src/transform/promote_initializers_to_const_var.h"
James Price567f2e42021-06-18 09:47:23 +000035#include "src/transform/simplify.h"
Ben Clayton0597a2b2021-06-16 09:19:36 +000036#include "src/transform/wrap_arrays_in_structs.h"
Ben Clayton75db82c2021-06-18 22:44:31 +000037#include "src/transform/zero_init_workgroup_memory.h"
Ben Clayton2101c352021-02-10 21:22:03 +000038
Ben Claytonb5cd10c2021-06-25 10:26:26 +000039TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl);
James Price960aa2e2021-06-19 00:33:35 +000040TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config);
41TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result);
42
Ben Clayton2101c352021-02-10 21:22:03 +000043namespace tint {
44namespace transform {
45
46Msl::Msl() = default;
47Msl::~Msl() = default;
48
James Price960aa2e2021-06-19 00:33:35 +000049Output Msl::Run(const Program* in, const DataMap& inputs) {
James Pricef8f31a42021-04-09 13:50:38 +000050 Manager manager;
James Price960aa2e2021-06-19 00:33:35 +000051 DataMap internal_inputs;
52
53 auto* cfg = inputs.Get<Config>();
54
James Pricec32e8f62021-06-22 20:08:29 +000055 // Build the configs for the internal transforms.
James Price960aa2e2021-06-19 00:33:35 +000056 uint32_t buffer_size_ubo_index = kDefaultBufferSizeUniformIndex;
James Pricec32e8f62021-06-22 20:08:29 +000057 uint32_t fixed_sample_mask = 0xFFFFFFFF;
James Price960aa2e2021-06-19 00:33:35 +000058 if (cfg) {
59 buffer_size_ubo_index = cfg->buffer_size_ubo_index;
James Pricec32e8f62021-06-22 20:08:29 +000060 fixed_sample_mask = cfg->fixed_sample_mask;
James Price960aa2e2021-06-19 00:33:35 +000061 }
62 auto array_length_from_uniform_cfg = ArrayLengthFromUniform::Config(
63 sem::BindingPoint{0, buffer_size_ubo_index});
James Pricec32e8f62021-06-22 20:08:29 +000064 auto entry_point_io_cfg = CanonicalizeEntryPointIO::Config(
65 CanonicalizeEntryPointIO::BuiltinStyle::kParameter, fixed_sample_mask);
James Price960aa2e2021-06-19 00:33:35 +000066
67 // Use the SSBO binding numbers as the indices for the buffer size lookups.
68 for (auto* var : in->AST().GlobalVariables()) {
Ben Clayton0f2d95d2021-07-22 13:24:59 +000069 auto* global = in->Sem().Get<sem::GlobalVariable>(var);
70 if (global && global->StorageClass() == ast::StorageClass::kStorage) {
James Price960aa2e2021-06-19 00:33:35 +000071 array_length_from_uniform_cfg.bindpoint_to_size_index.emplace(
Ben Clayton0f2d95d2021-07-22 13:24:59 +000072 global->BindingPoint(), global->BindingPoint().binding);
James Price960aa2e2021-06-19 00:33:35 +000073 }
74 }
75
Ben Clayton701820b2021-07-20 18:23:06 +000076 if (!cfg || !cfg->disable_workgroup_init) {
77 // ZeroInitWorkgroupMemory must come before CanonicalizeEntryPointIO as
78 // ZeroInitWorkgroupMemory may inject new builtin parameters.
79 manager.Add<ZeroInitWorkgroupMemory>();
80 }
James Pricef8f31a42021-04-09 13:50:38 +000081 manager.Add<CanonicalizeEntryPointIO>();
Brandon Jonesc705b6c2021-05-10 16:15:31 +000082 manager.Add<ExternalTextureTransform>();
James Price42220ba2021-06-01 12:08:20 +000083 manager.Add<PromoteInitializersToConstVar>();
Ben Clayton0597a2b2021-06-16 09:19:36 +000084 manager.Add<WrapArraysInStructs>();
Ben Clayton31936f32021-06-16 09:50:11 +000085 manager.Add<PadArrayElements>();
James Price567f2e42021-06-18 09:47:23 +000086 manager.Add<InlinePointerLets>();
87 manager.Add<Simplify>();
James Price960aa2e2021-06-19 00:33:35 +000088 // ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
89 // it assumes that the form of the array length argument is &var.array.
90 manager.Add<ArrayLengthFromUniform>();
91 internal_inputs.Add<ArrayLengthFromUniform::Config>(
92 std::move(array_length_from_uniform_cfg));
93 internal_inputs.Add<CanonicalizeEntryPointIO::Config>(
James Pricec32e8f62021-06-22 20:08:29 +000094 std::move(entry_point_io_cfg));
James Price960aa2e2021-06-19 00:33:35 +000095 auto out = manager.Run(in, internal_inputs);
James Pricef8f31a42021-04-09 13:50:38 +000096 if (!out.program.IsValid()) {
97 return out;
98 }
James Price7a47fa82021-05-26 15:41:02 +000099
100 ProgramBuilder builder;
101 CloneContext ctx(&builder, &out.program);
102 // TODO(jrprice): Consider making this a standalone transform, with target
103 // storage class(es) as transform options.
James Price830b97f2021-06-11 12:34:26 +0000104 HandleModuleScopeVariables(ctx);
James Price7a47fa82021-05-26 15:41:02 +0000105 ctx.Clone();
James Price960aa2e2021-06-19 00:33:35 +0000106
107 auto result = std::make_unique<Result>(
108 out.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
Ben Claytonb5cd10c2021-06-25 10:26:26 +0000109
110 builder.SetTransformApplied(this);
James Price960aa2e2021-06-19 00:33:35 +0000111 return Output{Program(std::move(builder)), std::move(result)};
James Price7a47fa82021-05-26 15:41:02 +0000112}
113
James Price830b97f2021-06-11 12:34:26 +0000114void Msl::HandleModuleScopeVariables(CloneContext& ctx) const {
James Price7a47fa82021-05-26 15:41:02 +0000115 // MSL does not allow private and workgroup variables at module-scope, so we
116 // push these declarations into the entry point function and then pass them as
117 // pointer parameters to any function that references them.
James Price830b97f2021-06-11 12:34:26 +0000118 // Similarly, texture and sampler types are converted to entry point
119 // parameters and passed by value to functions that need them.
James Price7a47fa82021-05-26 15:41:02 +0000120 //
121 // Since WGSL does not allow function-scope variables to have these storage
122 // classes, we annotate the new variable declarations with an attribute that
123 // bypasses that validation rule.
124 //
125 // Before:
126 // ```
127 // var<private> v : f32 = 2.0;
128 //
129 // fn foo() {
130 // v = v + 1.0;
131 // }
132 //
Sarahe6cb51e2021-06-29 18:39:44 +0000133 // [[stage(compute), workgroup_size(1)]]
James Price7a47fa82021-05-26 15:41:02 +0000134 // fn main() {
135 // foo();
136 // }
137 // ```
138 //
139 // After:
140 // ```
141 // fn foo(v : ptr<private, f32>) {
142 // *v = *v + 1.0;
143 // }
144 //
Sarahe6cb51e2021-06-29 18:39:44 +0000145 // [[stage(compute), workgroup_size(1)]]
James Price7a47fa82021-05-26 15:41:02 +0000146 // fn main() {
147 // var<private> v : f32 = 2.0;
James Price2940c702021-06-11 12:29:56 +0000148 // foo(&v);
James Price7a47fa82021-05-26 15:41:02 +0000149 // }
150 // ```
151
152 // Predetermine the list of function calls that need to be replaced.
153 using CallList = std::vector<const ast::CallExpression*>;
154 std::unordered_map<const ast::Function*, CallList> calls_to_replace;
155
156 std::vector<ast::Function*> functions_to_process;
157
158 // Build a list of functions that transitively reference any private or
James Price830b97f2021-06-11 12:34:26 +0000159 // workgroup variables, or texture/sampler variables.
James Price7a47fa82021-05-26 15:41:02 +0000160 for (auto* func_ast : ctx.src->AST().Functions()) {
161 auto* func_sem = ctx.src->Sem().Get(func_ast);
162
163 bool needs_processing = false;
164 for (auto* var : func_sem->ReferencedModuleVariables()) {
165 if (var->StorageClass() == ast::StorageClass::kPrivate ||
James Price830b97f2021-06-11 12:34:26 +0000166 var->StorageClass() == ast::StorageClass::kWorkgroup ||
167 var->StorageClass() == ast::StorageClass::kUniformConstant) {
James Price7a47fa82021-05-26 15:41:02 +0000168 needs_processing = true;
169 break;
170 }
171 }
172
173 if (needs_processing) {
174 functions_to_process.push_back(func_ast);
175
176 // Find all of the calls to this function that will need to be replaced.
177 for (auto* call : func_sem->CallSites()) {
178 auto* call_sem = ctx.src->Sem().Get(call);
179 calls_to_replace[call_sem->Stmt()->Function()].push_back(call);
180 }
181 }
182 }
183
James Price5c61d6d2021-08-04 19:18:38 +0000184 // Build a list of `&ident` expressions. We'll use this later to avoid
185 // generating expressions of the form `&*ident`, which break WGSL validation
186 // rules when this expression is passed to a function.
187 // TODO(jrprice): We should add support for bidirectional SEM tree traversal
188 // so that we can do this on the fly instead.
189 std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
190 ident_to_address_of;
191 for (auto* node : ctx.src->ASTNodes().Objects()) {
192 auto* address_of = node->As<ast::UnaryOpExpression>();
193 if (!address_of || address_of->op() != ast::UnaryOp::kAddressOf) {
194 continue;
195 }
196 if (auto* ident = address_of->expr()->As<ast::IdentifierExpression>()) {
197 ident_to_address_of[ident] = address_of;
198 }
199 }
200
James Price7a47fa82021-05-26 15:41:02 +0000201 for (auto* func_ast : functions_to_process) {
202 auto* func_sem = ctx.src->Sem().Get(func_ast);
James Price2940c702021-06-11 12:29:56 +0000203 bool is_entry_point = func_ast->IsEntryPoint();
James Price7a47fa82021-05-26 15:41:02 +0000204
205 // Map module-scope variables onto their function-scope replacement.
206 std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
207
208 for (auto* var : func_sem->ReferencedModuleVariables()) {
209 if (var->StorageClass() != ast::StorageClass::kPrivate &&
James Price830b97f2021-06-11 12:34:26 +0000210 var->StorageClass() != ast::StorageClass::kWorkgroup &&
211 var->StorageClass() != ast::StorageClass::kUniformConstant) {
James Price7a47fa82021-05-26 15:41:02 +0000212 continue;
213 }
214
James Price2940c702021-06-11 12:29:56 +0000215 // This is the symbol for the variable that replaces the module-scope var.
James Price7a47fa82021-05-26 15:41:02 +0000216 auto new_var_symbol = ctx.dst->Sym();
217
Ben Clayton96a6e7e2021-07-15 22:20:29 +0000218 auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
James Price7a47fa82021-05-26 15:41:02 +0000219
James Price2940c702021-06-11 12:29:56 +0000220 if (is_entry_point) {
James Price830b97f2021-06-11 12:34:26 +0000221 if (store_type->is_handle()) {
222 // For a texture or sampler variable, redeclare it as an entry point
223 // parameter. Disable entry point parameter validation.
224 auto* disable_validation =
225 ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
226 ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
227 auto decos = ctx.Clone(var->Declaration()->decorations());
228 decos.push_back(disable_validation);
229 auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
230 ctx.InsertFront(func_ast->params(), param);
231 } else {
232 // For a private or workgroup variable, redeclare it at function
233 // scope. Disable storage class validation on this variable.
234 auto* disable_validation =
235 ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
James Price14c0b8a2021-06-24 15:53:26 +0000236 ctx.dst->ID(), ast::DisabledValidation::kIgnoreStorageClass);
James Price830b97f2021-06-11 12:34:26 +0000237 auto* constructor = ctx.Clone(var->Declaration()->constructor());
Ben Clayton75db82c2021-06-18 22:44:31 +0000238 auto* local_var = ctx.dst->Var(
239 new_var_symbol, store_type, var->StorageClass(), constructor,
240 ast::DecorationList{disable_validation});
James Price830b97f2021-06-11 12:34:26 +0000241 ctx.InsertFront(func_ast->body()->statements(),
242 ctx.dst->Decl(local_var));
243 }
James Price7a47fa82021-05-26 15:41:02 +0000244 } else {
James Price830b97f2021-06-11 12:34:26 +0000245 // For a regular function, redeclare the variable as a parameter.
246 // Use a pointer for non-handle types.
247 auto* param_type = store_type;
248 if (!store_type->is_handle()) {
249 param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
250 }
James Price7a47fa82021-05-26 15:41:02 +0000251 ctx.InsertBack(func_ast->params(),
James Price830b97f2021-06-11 12:34:26 +0000252 ctx.dst->Param(new_var_symbol, param_type));
James Price7a47fa82021-05-26 15:41:02 +0000253 }
254
James Price2940c702021-06-11 12:29:56 +0000255 // Replace all uses of the module-scope variable.
James Price830b97f2021-06-11 12:34:26 +0000256 // For non-entry points, dereference non-handle pointer parameters.
James Price7a47fa82021-05-26 15:41:02 +0000257 for (auto* user : var->Users()) {
258 if (user->Stmt()->Function() == func_ast) {
James Price2940c702021-06-11 12:29:56 +0000259 ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
James Price830b97f2021-06-11 12:34:26 +0000260 if (!is_entry_point && !store_type->is_handle()) {
James Price5c61d6d2021-08-04 19:18:38 +0000261 // If this identifier is used by an address-of operator, just remove
262 // the address-of instead of adding a deref, since we already have a
263 // pointer.
264 auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
265 if (ident_to_address_of.count(ident)) {
266 ctx.Replace(ident_to_address_of[ident], expr);
267 continue;
268 }
269
James Price2940c702021-06-11 12:29:56 +0000270 expr = ctx.dst->Deref(expr);
271 }
272 ctx.Replace(user->Declaration(), expr);
James Price7a47fa82021-05-26 15:41:02 +0000273 }
274 }
275
276 var_to_symbol[var] = new_var_symbol;
277 }
278
James Price2940c702021-06-11 12:29:56 +0000279 // Pass the variables as pointers to any functions that need them.
James Price7a47fa82021-05-26 15:41:02 +0000280 for (auto* call : calls_to_replace[func_ast]) {
281 auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
282 auto* target_sem = ctx.src->Sem().Get(target);
283
James Price830b97f2021-06-11 12:34:26 +0000284 // Add new arguments for any variables that are needed by the callee.
285 // For entry points, pass non-handle types as pointers.
James Price7a47fa82021-05-26 15:41:02 +0000286 for (auto* target_var : target_sem->ReferencedModuleVariables()) {
287 if (target_var->StorageClass() == ast::StorageClass::kPrivate ||
James Price830b97f2021-06-11 12:34:26 +0000288 target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
289 target_var->StorageClass() == ast::StorageClass::kUniformConstant) {
James Price2940c702021-06-11 12:29:56 +0000290 ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
James Price830b97f2021-06-11 12:34:26 +0000291 if (is_entry_point && !target_var->Type()->UnwrapRef()->is_handle()) {
James Price2940c702021-06-11 12:29:56 +0000292 arg = ctx.dst->AddressOf(arg);
293 }
294 ctx.InsertBack(call->params(), arg);
James Price7a47fa82021-05-26 15:41:02 +0000295 }
296 }
297 }
298 }
299
James Price830b97f2021-06-11 12:34:26 +0000300 // Now remove all module-scope variables with these storage classes.
301 for (auto* var_ast : ctx.src->AST().GlobalVariables()) {
302 auto* var_sem = ctx.src->Sem().Get(var_ast);
303 if (var_sem->StorageClass() == ast::StorageClass::kPrivate ||
304 var_sem->StorageClass() == ast::StorageClass::kWorkgroup ||
305 var_sem->StorageClass() == ast::StorageClass::kUniformConstant) {
306 ctx.Remove(ctx.src->AST().GlobalDeclarations(), var_ast);
James Price7a47fa82021-05-26 15:41:02 +0000307 }
308 }
Ben Clayton2101c352021-02-10 21:22:03 +0000309}
310
Ben Clayton701820b2021-07-20 18:23:06 +0000311Msl::Config::Config(uint32_t buffer_size_ubo_idx,
312 uint32_t sample_mask,
313 bool disable_wi)
James Pricec32e8f62021-06-22 20:08:29 +0000314 : buffer_size_ubo_index(buffer_size_ubo_idx),
Ben Clayton701820b2021-07-20 18:23:06 +0000315 fixed_sample_mask(sample_mask),
316 disable_workgroup_init(disable_wi) {}
James Price960aa2e2021-06-19 00:33:35 +0000317Msl::Config::Config(const Config&) = default;
318Msl::Config::~Config() = default;
319
320Msl::Result::Result(bool needs_buffer_sizes)
321 : needs_storage_buffer_sizes(needs_buffer_sizes) {}
322Msl::Result::Result(const Result&) = default;
323Msl::Result::~Result() = default;
324
Ben Clayton2101c352021-02-10 21:22:03 +0000325} // namespace transform
326} // namespace tint