blob: 19b39d631088aff55c875dc466cc6c052fa9f6b3 [file] [log] [blame]
Chris Forbescc5697f2019-01-30 11:54:08 -08001// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
// Position, among the in-operand constants of an OpCompositeExtract, of the
// composite that is being read from.
constexpr uint32_t kExtractCompositeIdInIdx{0};
Chris Forbescc5697f2019-01-30 11:54:08 -080023
Nicolas Capens6cacf182021-11-30 11:15:46 -050024// Returns a constants with the value NaN of the given type. Only works for
25// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
26const analysis::Constant* GetNan(const analysis::Type* type,
27 analysis::ConstantManager* const_mgr) {
28 const analysis::Float* float_type = type->AsFloat();
29 if (float_type == nullptr) {
30 return nullptr;
31 }
32
33 switch (float_type->width()) {
34 case 32:
35 return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
36 case 64:
37 return const_mgr->GetDoubleConst(
38 std::numeric_limits<double>::quiet_NaN());
39 default:
40 return nullptr;
41 }
42}
43
44// Returns a constants with the value INF of the given type. Only works for
45// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
46const analysis::Constant* GetInf(const analysis::Type* type,
47 analysis::ConstantManager* const_mgr) {
48 const analysis::Float* float_type = type->AsFloat();
49 if (float_type == nullptr) {
50 return nullptr;
51 }
52
53 switch (float_type->width()) {
54 case 32:
55 return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
56 case 64:
57 return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
58 default:
59 return nullptr;
60 }
61}
62
Chris Forbescc5697f2019-01-30 11:54:08 -080063// Returns true if |type| is Float or a vector of Float.
64bool HasFloatingPoint(const analysis::Type* type) {
65 if (type->AsFloat()) {
66 return true;
67 } else if (const analysis::Vector* vec_type = type->AsVector()) {
68 return vec_type->element_type()->AsFloat() != nullptr;
69 }
70
71 return false;
72}
73
Nicolas Capens6cacf182021-11-30 11:15:46 -050074// Returns a constants with the value |-val| of the given type. Only works for
75// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
Nicolas Capens84c9c452022-11-18 14:11:05 +000076const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
Nicolas Capens6cacf182021-11-30 11:15:46 -050077 const analysis::Constant* val,
78 analysis::ConstantManager* const_mgr) {
79 const analysis::Float* float_type = result_type->AsFloat();
80 assert(float_type != nullptr);
81 if (float_type->width() == 32) {
82 float fa = val->GetFloat();
83 return const_mgr->GetFloatConst(-fa);
84 } else if (float_type->width() == 64) {
85 double da = val->GetDouble();
86 return const_mgr->GetDoubleConst(-da);
87 }
88 return nullptr;
89}
90
Chris Forbescc5697f2019-01-30 11:54:08 -080091// Folds an OpcompositeExtract where input is a composite constant.
92ConstantFoldingRule FoldExtractWithConstants() {
93 return [](IRContext* context, Instruction* inst,
94 const std::vector<const analysis::Constant*>& constants)
95 -> const analysis::Constant* {
96 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
97 if (c == nullptr) {
98 return nullptr;
99 }
100
101 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
102 uint32_t element_index = inst->GetSingleWordInOperand(i);
103 if (c->AsNullConstant()) {
104 // Return Null for the return type.
105 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
106 analysis::TypeManager* type_mgr = context->get_type_mgr();
107 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
108 }
109
110 auto cc = c->AsCompositeConstant();
111 assert(cc != nullptr);
112 auto components = cc->GetComponents();
Ben Claytond0f684e2019-08-30 22:36:08 +0100113 // Protect against invalid IR. Refuse to fold if the index is out
114 // of bounds.
115 if (element_index >= components.size()) return nullptr;
Chris Forbescc5697f2019-01-30 11:54:08 -0800116 c = components[element_index];
117 }
118 return c;
119 };
120}
121
Nicolas Capens84c9c452022-11-18 14:11:05 +0000122// Folds an OpcompositeInsert where input is a composite constant.
123ConstantFoldingRule FoldInsertWithConstants() {
124 return [](IRContext* context, Instruction* inst,
125 const std::vector<const analysis::Constant*>& constants)
126 -> const analysis::Constant* {
127 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
128 const analysis::Constant* object = constants[0];
129 const analysis::Constant* composite = constants[1];
130 if (object == nullptr || composite == nullptr) {
131 return nullptr;
132 }
133
134 // If there is more than 1 index, then each additional constant used by the
135 // index will need to be recreated to use the inserted object.
136 std::vector<const analysis::Constant*> chain;
137 std::vector<const analysis::Constant*> components;
138 const analysis::Type* type = nullptr;
139
140 // Work down hierarchy and add all the indexes, not including the final
141 // index.
142 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
Alexis Hetu1ef51fa2022-11-24 09:03:10 -0500143 if (composite->AsNullConstant()) {
144 // Return Null for the return type.
145 analysis::TypeManager* type_mgr = context->get_type_mgr();
146 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
147 }
148
Nicolas Capens84c9c452022-11-18 14:11:05 +0000149 if (i != inst->NumInOperands() - 1) {
150 chain.push_back(composite);
151 }
152 const uint32_t index = inst->GetSingleWordInOperand(i);
153 components = composite->AsCompositeConstant()->GetComponents();
154 type = composite->AsCompositeConstant()->type();
155 composite = components[index];
156 }
157
158 // Final index in hierarchy is inserted with new object.
159 const uint32_t final_index =
160 inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
161 std::vector<uint32_t> ids;
162 for (size_t i = 0; i < components.size(); i++) {
163 const analysis::Constant* constant =
164 (i == final_index) ? object : components[i];
165 Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
166 ids.push_back(member_inst->result_id());
167 }
168 const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
169
170 // Work backwards up the chain and replace each index with new constant.
171 for (size_t i = chain.size(); i > 0; i--) {
172 // Need to insert any previous instruction into the module first.
173 // Can't just insert in types_values_begin() because it will move above
174 // where the types are declared
175 for (Module::inst_iterator inst_iter = context->types_values_begin();
176 inst_iter != context->types_values_end(); ++inst_iter) {
177 Instruction* x = &*inst_iter;
178 if (inst->result_id() == x->result_id()) {
179 const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
180 break;
181 }
182 }
183
184 composite = chain[i - 1];
185 components = composite->AsCompositeConstant()->GetComponents();
186 type = composite->AsCompositeConstant()->type();
187 ids.clear();
188 for (size_t k = 0; k < components.size(); k++) {
189 const uint32_t index =
190 inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
191 const analysis::Constant* constant =
192 (k == index) ? new_constant : components[k];
193 const uint32_t constant_id =
194 const_mgr->FindDeclaredConstant(constant, 0);
195 ids.push_back(constant_id);
196 }
197 new_constant = const_mgr->GetConstant(type, ids);
198 }
199
200 // If multiple constants were created, only need to return the top index.
201 return new_constant;
202 };
203}
204
Chris Forbescc5697f2019-01-30 11:54:08 -0800205ConstantFoldingRule FoldVectorShuffleWithConstants() {
206 return [](IRContext* context, Instruction* inst,
207 const std::vector<const analysis::Constant*>& constants)
208 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000209 assert(inst->opcode() == spv::Op::OpVectorShuffle);
Chris Forbescc5697f2019-01-30 11:54:08 -0800210 const analysis::Constant* c1 = constants[0];
211 const analysis::Constant* c2 = constants[1];
212 if (c1 == nullptr || c2 == nullptr) {
213 return nullptr;
214 }
215
216 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
217 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
218
219 std::vector<const analysis::Constant*> c1_components;
220 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
221 c1_components = vec_const->GetComponents();
222 } else {
223 assert(c1->AsNullConstant());
224 const analysis::Constant* element =
225 const_mgr->GetConstant(element_type, {});
226 c1_components.resize(c1->type()->AsVector()->element_count(), element);
227 }
228 std::vector<const analysis::Constant*> c2_components;
229 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
230 c2_components = vec_const->GetComponents();
231 } else {
232 assert(c2->AsNullConstant());
233 const analysis::Constant* element =
234 const_mgr->GetConstant(element_type, {});
235 c2_components.resize(c2->type()->AsVector()->element_count(), element);
236 }
237
238 std::vector<uint32_t> ids;
239 const uint32_t undef_literal_value = 0xffffffff;
240 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
241 uint32_t index = inst->GetSingleWordInOperand(i);
242 if (index == undef_literal_value) {
243 // Don't fold shuffle with undef literal value.
244 return nullptr;
245 } else if (index < c1_components.size()) {
246 Instruction* member_inst =
247 const_mgr->GetDefiningInstruction(c1_components[index]);
248 ids.push_back(member_inst->result_id());
249 } else {
250 Instruction* member_inst = const_mgr->GetDefiningInstruction(
251 c2_components[index - c1_components.size()]);
252 ids.push_back(member_inst->result_id());
253 }
254 }
255
256 analysis::TypeManager* type_mgr = context->get_type_mgr();
257 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
258 };
259}
260
261ConstantFoldingRule FoldVectorTimesScalar() {
262 return [](IRContext* context, Instruction* inst,
263 const std::vector<const analysis::Constant*>& constants)
264 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000265 assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
Chris Forbescc5697f2019-01-30 11:54:08 -0800266 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
267 analysis::TypeManager* type_mgr = context->get_type_mgr();
268
269 if (!inst->IsFloatingPointFoldingAllowed()) {
270 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
271 return nullptr;
272 }
273 }
274
275 const analysis::Constant* c1 = constants[0];
276 const analysis::Constant* c2 = constants[1];
277
278 if (c1 && c1->IsZero()) {
279 return c1;
280 }
281
282 if (c2 && c2->IsZero()) {
283 // Get or create the NullConstant for this type.
284 std::vector<uint32_t> ids;
285 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
286 }
287
288 if (c1 == nullptr || c2 == nullptr) {
289 return nullptr;
290 }
291
292 // Check result type.
293 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
294 const analysis::Vector* vector_type = result_type->AsVector();
295 assert(vector_type != nullptr);
296 const analysis::Type* element_type = vector_type->element_type();
297 assert(element_type != nullptr);
298 const analysis::Float* float_type = element_type->AsFloat();
299 assert(float_type != nullptr);
300
301 // Check types of c1 and c2.
302 assert(c1->type()->AsVector() == vector_type);
303 assert(c1->type()->AsVector()->element_type() == element_type &&
304 c2->type() == element_type);
305
306 // Get a float vector that is the result of vector-times-scalar.
307 std::vector<const analysis::Constant*> c1_components =
308 c1->GetVectorComponents(const_mgr);
309 std::vector<uint32_t> ids;
310 if (float_type->width() == 32) {
311 float scalar = c2->GetFloat();
312 for (uint32_t i = 0; i < c1_components.size(); ++i) {
313 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
314 std::vector<uint32_t> words = result.GetWords();
315 const analysis::Constant* new_elem =
316 const_mgr->GetConstant(float_type, words);
317 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
318 }
319 return const_mgr->GetConstant(vector_type, ids);
320 } else if (float_type->width() == 64) {
321 double scalar = c2->GetDouble();
322 for (uint32_t i = 0; i < c1_components.size(); ++i) {
323 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
324 scalar);
325 std::vector<uint32_t> words = result.GetWords();
326 const analysis::Constant* new_elem =
327 const_mgr->GetConstant(float_type, words);
328 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
329 }
330 return const_mgr->GetConstant(vector_type, ids);
331 }
332 return nullptr;
333 };
334}
335
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400336ConstantFoldingRule FoldVectorTimesMatrix() {
337 return [](IRContext* context, Instruction* inst,
338 const std::vector<const analysis::Constant*>& constants)
339 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000340 assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400341 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
342 analysis::TypeManager* type_mgr = context->get_type_mgr();
343
344 if (!inst->IsFloatingPointFoldingAllowed()) {
345 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
346 return nullptr;
347 }
348 }
349
350 const analysis::Constant* c1 = constants[0];
351 const analysis::Constant* c2 = constants[1];
352
353 if (c1 == nullptr || c2 == nullptr) {
354 return nullptr;
355 }
356
357 // Check result type.
358 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
359 const analysis::Vector* vector_type = result_type->AsVector();
360 assert(vector_type != nullptr);
361 const analysis::Type* element_type = vector_type->element_type();
362 assert(element_type != nullptr);
363 const analysis::Float* float_type = element_type->AsFloat();
364 assert(float_type != nullptr);
365
366 // Check types of c1 and c2.
367 assert(c1->type()->AsVector() == vector_type);
368 assert(c1->type()->AsVector()->element_type() == element_type &&
369 c2->type()->AsMatrix()->element_type() == vector_type);
370
371 // Get a float vector that is the result of vector-times-matrix.
372 std::vector<const analysis::Constant*> c1_components =
373 c1->GetVectorComponents(const_mgr);
374 std::vector<const analysis::Constant*> c2_components =
375 c2->AsMatrixConstant()->GetComponents();
376 uint32_t resultVectorSize = result_type->AsVector()->element_count();
377
378 std::vector<uint32_t> ids;
379
380 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
381 std::vector<uint32_t> words(float_type->width() / 32, 0);
382 for (uint32_t i = 0; i < resultVectorSize; ++i) {
383 const analysis::Constant* new_elem =
384 const_mgr->GetConstant(float_type, words);
385 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
386 }
387 return const_mgr->GetConstant(vector_type, ids);
388 }
389
390 if (float_type->width() == 32) {
391 for (uint32_t i = 0; i < resultVectorSize; ++i) {
392 float result_scalar = 0.0f;
393 const analysis::VectorConstant* c2_vec =
394 c2_components[i]->AsVectorConstant();
395 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
396 float c1_scalar = c1_components[j]->GetFloat();
397 float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
398 result_scalar += c1_scalar * c2_scalar;
399 }
400 utils::FloatProxy<float> result(result_scalar);
401 std::vector<uint32_t> words = result.GetWords();
402 const analysis::Constant* new_elem =
403 const_mgr->GetConstant(float_type, words);
404 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
405 }
406 return const_mgr->GetConstant(vector_type, ids);
407 } else if (float_type->width() == 64) {
408 for (uint32_t i = 0; i < c2_components.size(); ++i) {
409 double result_scalar = 0.0;
410 const analysis::VectorConstant* c2_vec =
411 c2_components[i]->AsVectorConstant();
412 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
413 double c1_scalar = c1_components[j]->GetDouble();
414 double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
415 result_scalar += c1_scalar * c2_scalar;
416 }
417 utils::FloatProxy<double> result(result_scalar);
418 std::vector<uint32_t> words = result.GetWords();
419 const analysis::Constant* new_elem =
420 const_mgr->GetConstant(float_type, words);
421 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
422 }
423 return const_mgr->GetConstant(vector_type, ids);
424 }
425 return nullptr;
426 };
427}
428
429ConstantFoldingRule FoldMatrixTimesVector() {
430 return [](IRContext* context, Instruction* inst,
431 const std::vector<const analysis::Constant*>& constants)
432 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000433 assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400434 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
435 analysis::TypeManager* type_mgr = context->get_type_mgr();
436
437 if (!inst->IsFloatingPointFoldingAllowed()) {
438 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
439 return nullptr;
440 }
441 }
442
443 const analysis::Constant* c1 = constants[0];
444 const analysis::Constant* c2 = constants[1];
445
446 if (c1 == nullptr || c2 == nullptr) {
447 return nullptr;
448 }
449
450 // Check result type.
451 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
452 const analysis::Vector* vector_type = result_type->AsVector();
453 assert(vector_type != nullptr);
454 const analysis::Type* element_type = vector_type->element_type();
455 assert(element_type != nullptr);
456 const analysis::Float* float_type = element_type->AsFloat();
457 assert(float_type != nullptr);
458
459 // Check types of c1 and c2.
460 assert(c1->type()->AsMatrix()->element_type() == vector_type);
461 assert(c2->type()->AsVector()->element_type() == element_type);
462
463 // Get a float vector that is the result of matrix-times-vector.
464 std::vector<const analysis::Constant*> c1_components =
465 c1->AsMatrixConstant()->GetComponents();
466 std::vector<const analysis::Constant*> c2_components =
467 c2->GetVectorComponents(const_mgr);
468 uint32_t resultVectorSize = result_type->AsVector()->element_count();
469
470 std::vector<uint32_t> ids;
471
472 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
473 std::vector<uint32_t> words(float_type->width() / 32, 0);
474 for (uint32_t i = 0; i < resultVectorSize; ++i) {
475 const analysis::Constant* new_elem =
476 const_mgr->GetConstant(float_type, words);
477 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
478 }
479 return const_mgr->GetConstant(vector_type, ids);
480 }
481
482 if (float_type->width() == 32) {
483 for (uint32_t i = 0; i < resultVectorSize; ++i) {
484 float result_scalar = 0.0f;
485 for (uint32_t j = 0; j < c1_components.size(); ++j) {
486 float c1_scalar = c1_components[j]
487 ->AsVectorConstant()
488 ->GetComponents()[i]
489 ->GetFloat();
490 float c2_scalar = c2_components[j]->GetFloat();
491 result_scalar += c1_scalar * c2_scalar;
492 }
493 utils::FloatProxy<float> result(result_scalar);
494 std::vector<uint32_t> words = result.GetWords();
495 const analysis::Constant* new_elem =
496 const_mgr->GetConstant(float_type, words);
497 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
498 }
499 return const_mgr->GetConstant(vector_type, ids);
500 } else if (float_type->width() == 64) {
501 for (uint32_t i = 0; i < resultVectorSize; ++i) {
502 double result_scalar = 0.0;
503 for (uint32_t j = 0; j < c1_components.size(); ++j) {
504 double c1_scalar = c1_components[j]
505 ->AsVectorConstant()
506 ->GetComponents()[i]
507 ->GetDouble();
508 double c2_scalar = c2_components[j]->GetDouble();
509 result_scalar += c1_scalar * c2_scalar;
510 }
511 utils::FloatProxy<double> result(result_scalar);
512 std::vector<uint32_t> words = result.GetWords();
513 const analysis::Constant* new_elem =
514 const_mgr->GetConstant(float_type, words);
515 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
516 }
517 return const_mgr->GetConstant(vector_type, ids);
518 }
519 return nullptr;
520 };
521}
522
Chris Forbescc5697f2019-01-30 11:54:08 -0800523ConstantFoldingRule FoldCompositeWithConstants() {
524 // Folds an OpCompositeConstruct where all of the inputs are constants to a
525 // constant. A new constant is created if necessary.
526 return [](IRContext* context, Instruction* inst,
527 const std::vector<const analysis::Constant*>& constants)
528 -> const analysis::Constant* {
529 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
530 analysis::TypeManager* type_mgr = context->get_type_mgr();
531 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
532 Instruction* type_inst =
533 context->get_def_use_mgr()->GetDef(inst->type_id());
534
535 std::vector<uint32_t> ids;
536 for (uint32_t i = 0; i < constants.size(); ++i) {
537 const analysis::Constant* element_const = constants[i];
538 if (element_const == nullptr) {
539 return nullptr;
540 }
541
542 uint32_t component_type_id = 0;
Nicolas Capens84c9c452022-11-18 14:11:05 +0000543 if (type_inst->opcode() == spv::Op::OpTypeStruct) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800544 component_type_id = type_inst->GetSingleWordInOperand(i);
Nicolas Capens84c9c452022-11-18 14:11:05 +0000545 } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800546 component_type_id = type_inst->GetSingleWordInOperand(0);
547 }
548
549 uint32_t element_id =
550 const_mgr->FindDeclaredConstant(element_const, component_type_id);
551 if (element_id == 0) {
552 return nullptr;
553 }
554 ids.push_back(element_id);
555 }
556 return const_mgr->GetConstant(new_type, ids);
557 };
558}
559
560// The interface for a function that returns the result of applying a scalar
561// floating-point binary operation on |a| and |b|. The type of the return value
562// will be |type|. The input constants must also be of type |type|.
563using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
564 const analysis::Type* result_type, const analysis::Constant* a,
565 analysis::ConstantManager*)>;
566
567// The interface for a function that returns the result of applying a scalar
568// floating-point binary operation on |a| and |b|. The type of the return value
569// will be |type|. The input constants must also be of type |type|.
570using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
571 const analysis::Type* result_type, const analysis::Constant* a,
572 const analysis::Constant* b, analysis::ConstantManager*)>;
573
574// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
575// using |scalar_rule| and unary float point vectors ops by applying
576// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
577// that is returned assumes that |constants| contains 1 entry. If they are
578// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
579// whose element type is |Float| or |Integer|.
580ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
581 return [scalar_rule](IRContext* context, Instruction* inst,
582 const std::vector<const analysis::Constant*>& constants)
583 -> const analysis::Constant* {
584 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
585 analysis::TypeManager* type_mgr = context->get_type_mgr();
586 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
587 const analysis::Vector* vector_type = result_type->AsVector();
588
589 if (!inst->IsFloatingPointFoldingAllowed()) {
590 return nullptr;
591 }
592
Ben Claytondc6b76a2020-02-24 14:53:40 +0000593 const analysis::Constant* arg =
Nicolas Capens84c9c452022-11-18 14:11:05 +0000594 (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
Ben Claytondc6b76a2020-02-24 14:53:40 +0000595
596 if (arg == nullptr) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800597 return nullptr;
598 }
599
600 if (vector_type != nullptr) {
601 std::vector<const analysis::Constant*> a_components;
602 std::vector<const analysis::Constant*> results_components;
603
Ben Claytondc6b76a2020-02-24 14:53:40 +0000604 a_components = arg->GetVectorComponents(const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800605
606 // Fold each component of the vector.
607 for (uint32_t i = 0; i < a_components.size(); ++i) {
608 results_components.push_back(scalar_rule(vector_type->element_type(),
609 a_components[i], const_mgr));
610 if (results_components[i] == nullptr) {
611 return nullptr;
612 }
613 }
614
615 // Build the constant object and return it.
616 std::vector<uint32_t> ids;
617 for (const analysis::Constant* member : results_components) {
618 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
619 }
620 return const_mgr->GetConstant(vector_type, ids);
621 } else {
Ben Claytondc6b76a2020-02-24 14:53:40 +0000622 return scalar_rule(result_type, arg, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800623 }
624 };
625}
626
Ben Claytond552f632019-11-18 11:18:41 +0000627// Returns the result of folding the constants in |constants| according the
628// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
629// per component.
630const analysis::Constant* FoldFPBinaryOp(
631 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
632 const std::vector<const analysis::Constant*>& constants,
633 IRContext* context) {
634 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
635 analysis::TypeManager* type_mgr = context->get_type_mgr();
636 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
637 const analysis::Vector* vector_type = result_type->AsVector();
638
639 if (constants[0] == nullptr || constants[1] == nullptr) {
640 return nullptr;
641 }
642
643 if (vector_type != nullptr) {
644 std::vector<const analysis::Constant*> a_components;
645 std::vector<const analysis::Constant*> b_components;
646 std::vector<const analysis::Constant*> results_components;
647
648 a_components = constants[0]->GetVectorComponents(const_mgr);
649 b_components = constants[1]->GetVectorComponents(const_mgr);
650
651 // Fold each component of the vector.
652 for (uint32_t i = 0; i < a_components.size(); ++i) {
653 results_components.push_back(scalar_rule(vector_type->element_type(),
654 a_components[i], b_components[i],
655 const_mgr));
656 if (results_components[i] == nullptr) {
657 return nullptr;
658 }
659 }
660
661 // Build the constant object and return it.
662 std::vector<uint32_t> ids;
663 for (const analysis::Constant* member : results_components) {
664 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
665 }
666 return const_mgr->GetConstant(vector_type, ids);
667 } else {
668 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
669 }
670}
671
Chris Forbescc5697f2019-01-30 11:54:08 -0800672// Returns a |ConstantFoldingRule| that folds floating point scalars using
673// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
674// elements of the vector. The |ConstantFoldingRule| that is returned assumes
675// that |constants| contains 2 entries. If they are not |nullptr|, then their
676// type is either |Float| or a |Vector| whose element type is |Float|.
677ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
678 return [scalar_rule](IRContext* context, Instruction* inst,
679 const std::vector<const analysis::Constant*>& constants)
680 -> const analysis::Constant* {
Chris Forbescc5697f2019-01-30 11:54:08 -0800681 if (!inst->IsFloatingPointFoldingAllowed()) {
682 return nullptr;
683 }
Nicolas Capens84c9c452022-11-18 14:11:05 +0000684 if (inst->opcode() == spv::Op::OpExtInst) {
Ben Claytond552f632019-11-18 11:18:41 +0000685 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
686 {constants[1], constants[2]}, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800687 }
Ben Claytond552f632019-11-18 11:18:41 +0000688 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800689 };
690}
691
692// This macro defines a |UnaryScalarFoldingRule| that performs float to
693// integer conversion.
694// TODO(greg-lunarg): Support for 64-bit integer types.
695UnaryScalarFoldingRule FoldFToIOp() {
696 return [](const analysis::Type* result_type, const analysis::Constant* a,
697 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
698 assert(result_type != nullptr && a != nullptr);
699 const analysis::Integer* integer_type = result_type->AsInteger();
700 const analysis::Float* float_type = a->type()->AsFloat();
701 assert(float_type != nullptr);
702 assert(integer_type != nullptr);
703 if (integer_type->width() != 32) return nullptr;
704 if (float_type->width() == 32) {
705 float fa = a->GetFloat();
706 uint32_t result = integer_type->IsSigned()
707 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
708 : static_cast<uint32_t>(fa);
709 std::vector<uint32_t> words = {result};
710 return const_mgr->GetConstant(result_type, words);
711 } else if (float_type->width() == 64) {
712 double fa = a->GetDouble();
713 uint32_t result = integer_type->IsSigned()
714 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
715 : static_cast<uint32_t>(fa);
716 std::vector<uint32_t> words = {result};
717 return const_mgr->GetConstant(result_type, words);
718 }
719 return nullptr;
720 };
721}
722
723// This function defines a |UnaryScalarFoldingRule| that performs integer to
724// float conversion.
725// TODO(greg-lunarg): Support for 64-bit integer types.
726UnaryScalarFoldingRule FoldIToFOp() {
727 return [](const analysis::Type* result_type, const analysis::Constant* a,
728 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
729 assert(result_type != nullptr && a != nullptr);
730 const analysis::Integer* integer_type = a->type()->AsInteger();
731 const analysis::Float* float_type = result_type->AsFloat();
732 assert(float_type != nullptr);
733 assert(integer_type != nullptr);
734 if (integer_type->width() != 32) return nullptr;
735 uint32_t ua = a->GetU32();
736 if (float_type->width() == 32) {
737 float result_val = integer_type->IsSigned()
738 ? static_cast<float>(static_cast<int32_t>(ua))
739 : static_cast<float>(ua);
740 utils::FloatProxy<float> result(result_val);
741 std::vector<uint32_t> words = {result.data()};
742 return const_mgr->GetConstant(result_type, words);
743 } else if (float_type->width() == 64) {
744 double result_val = integer_type->IsSigned()
745 ? static_cast<double>(static_cast<int32_t>(ua))
746 : static_cast<double>(ua);
747 utils::FloatProxy<double> result(result_val);
748 std::vector<uint32_t> words = result.GetWords();
749 return const_mgr->GetConstant(result_type, words);
750 }
751 return nullptr;
752 };
753}
754
Ben Claytonb73b7602019-07-29 13:56:13 +0100755// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
756UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
757 return [](const analysis::Type* result_type, const analysis::Constant* a,
758 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
759 assert(result_type != nullptr && a != nullptr);
760 const analysis::Float* float_type = a->type()->AsFloat();
761 assert(float_type != nullptr);
762 if (float_type->width() != 32) {
763 return nullptr;
764 }
765
766 float fa = a->GetFloat();
767 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
768 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
769 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
770 orignal.castTo(quantized, utils::round_direction::kToZero);
771 quantized.castTo(result, utils::round_direction::kToZero);
772 std::vector<uint32_t> words = {result.getBits()};
773 return const_mgr->GetConstant(result_type, words);
774 };
775}
776
// This macro defines a |BinaryScalarFoldingRule| that applies |op|. The
// operator |op| must work for both float and double, and use syntax "f1 op f2".
// Both operands and the result must share one float type. For widths 32 and
// 64 the arithmetic is done in float or double respectively; any other width
// makes the rule return nullptr.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    std::vector<uint32_t> words_in_macro;                                     \
    switch (float_type_in_macro->width()) {                                   \
      case 32:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<float>(a->GetFloat() op b->GetFloat())          \
                .GetWords();                                                  \
        break;                                                                \
      case 64:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<double>(a->GetDouble() op b->GetDouble())       \
                .GetWords();                                                  \
        break;                                                                \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
    return const_mgr_in_macro->GetConstant(result_type_in_macro,              \
                                           words_in_macro);                   \
  }
807
808// Define the folding rule for conversion between floating point and integer
809ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
810ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100811ConstantFoldingRule FoldQuantizeToF16() {
812 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
813}
Chris Forbescc5697f2019-01-30 11:54:08 -0800814
815// Define the folding rules for subtraction, addition, multiplication, and
816// division for floating point values.
817ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
818ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
819ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
Nicolas Capens6cacf182021-11-30 11:15:46 -0500820
821// Returns the constant that results from evaluating |numerator| / 0.0. Returns
sugoi1b398bf32022-02-18 10:27:28 -0500822// |nullptr| if the result could not be evaluated.
Nicolas Capens6cacf182021-11-30 11:15:46 -0500823const analysis::Constant* FoldFPScalarDivideByZero(
824 const analysis::Type* result_type, const analysis::Constant* numerator,
825 analysis::ConstantManager* const_mgr) {
826 if (numerator == nullptr) {
827 return nullptr;
828 }
829
830 if (numerator->IsZero()) {
831 return GetNan(result_type, const_mgr);
832 }
833
834 const analysis::Constant* result = GetInf(result_type, const_mgr);
835 if (result == nullptr) {
836 return nullptr;
837 }
838
839 if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000840 result = NegateFPConst(result_type, result, const_mgr);
Nicolas Capens6cacf182021-11-30 11:15:46 -0500841 }
842 return result;
843}
844
845// Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
846// if it cannot be folded.
847const analysis::Constant* FoldScalarFPDivide(
848 const analysis::Type* result_type, const analysis::Constant* numerator,
849 const analysis::Constant* denominator,
850 analysis::ConstantManager* const_mgr) {
851 if (denominator == nullptr) {
852 return nullptr;
853 }
854
855 if (denominator->IsZero()) {
856 return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
857 }
858
859 const analysis::FloatConstant* denominator_float =
860 denominator->AsFloatConstant();
861 if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
862 const analysis::Constant* result =
863 FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
864 if (result != nullptr)
Nicolas Capens84c9c452022-11-18 14:11:05 +0000865 result = NegateFPConst(result_type, result, const_mgr);
Nicolas Capens6cacf182021-11-30 11:15:46 -0500866 return result;
867 } else {
868 return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
869 }
870}
871
872// Returns the constant folding rule to fold |OpFDiv| with two constants.
873ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
Chris Forbescc5697f2019-01-30 11:54:08 -0800874
// Combines the raw C++ comparison result |op_result| with NaN information.
// |op_unordered| is true when at least one operand is NaN. When
// |need_ordered| is true the comparison is an ordered one, which is false
// whenever a NaN is involved. Otherwise it is an unordered one, which is true
// whenever a NaN is involved. Without NaNs, |op_result| decides.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  if (op_unordered) {
    return !need_ordered;
  }
  return op_result;
}
885
// This macro defines a |BinaryScalarFoldingRule| that compares two float
// constants with |op|. The operator must work for both float and double and
// use the syntax "f1 op f2". |ord| is passed to CompareFloatingPoint:
//   - true: ordered comparison, false if either operand is NaN;
//   - false: unordered comparison, true if either operand is NaN.
// The result is a boolean constant of |result_type|. Float widths other than
// 32 and 64 yield nullptr.
#define FOLD_FPCMP_OP(op, ord)                                                \
  [](const analysis::Type* result_type, const analysis::Constant* a,          \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {     \
    assert(result_type != nullptr && a != nullptr && b != nullptr);           \
    assert(result_type->AsBool());                                            \
    assert(a->type() == b->type());                                           \
    const analysis::Float* float_type = a->type()->AsFloat();                 \
    assert(float_type != nullptr);                                            \
    bool result = false;                                                      \
    if (float_type->width() == 32) {                                          \
      const float lhs = a->GetFloat();                                        \
      const float rhs = b->GetFloat();                                        \
      const bool any_nan = std::isnan(lhs) || std::isnan(rhs);                \
      result = CompareFloatingPoint(lhs op rhs, any_nan, ord);                \
    } else if (float_type->width() == 64) {                                   \
      const double lhs = a->GetDouble();                                      \
      const double rhs = b->GetDouble();                                      \
      const bool any_nan = std::isnan(lhs) || std::isnan(rhs);                \
      result = CompareFloatingPoint(lhs op rhs, any_nan, ord);                \
    } else {                                                                  \
      return nullptr;                                                         \
    }                                                                         \
    std::vector<uint32_t> words = {static_cast<uint32_t>(result)};            \
    return const_mgr->GetConstant(result_type, words);                        \
  }
914
915// Define the folding rules for ordered and unordered comparison for floating
916// point values.
917ConstantFoldingRule FoldFOrdEqual() {
918 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
919}
920ConstantFoldingRule FoldFUnordEqual() {
921 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
922}
923ConstantFoldingRule FoldFOrdNotEqual() {
924 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
925}
926ConstantFoldingRule FoldFUnordNotEqual() {
927 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
928}
929ConstantFoldingRule FoldFOrdLessThan() {
930 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
931}
932ConstantFoldingRule FoldFUnordLessThan() {
933 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
934}
935ConstantFoldingRule FoldFOrdGreaterThan() {
936 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
937}
938ConstantFoldingRule FoldFUnordGreaterThan() {
939 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
940}
941ConstantFoldingRule FoldFOrdLessThanEqual() {
942 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
943}
944ConstantFoldingRule FoldFUnordLessThanEqual() {
945 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
946}
947ConstantFoldingRule FoldFOrdGreaterThanEqual() {
948 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
949}
950ConstantFoldingRule FoldFUnordGreaterThanEqual() {
951 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
952}
953
954// Folds an OpDot where all of the inputs are constants to a
955// constant. A new constant is created if necessary.
956ConstantFoldingRule FoldOpDotWithConstants() {
957 return [](IRContext* context, Instruction* inst,
958 const std::vector<const analysis::Constant*>& constants)
959 -> const analysis::Constant* {
960 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
961 analysis::TypeManager* type_mgr = context->get_type_mgr();
962 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
963 assert(new_type->AsFloat() && "OpDot should have a float return type.");
964 const analysis::Float* float_type = new_type->AsFloat();
965
966 if (!inst->IsFloatingPointFoldingAllowed()) {
967 return nullptr;
968 }
969
970 // If one of the operands is 0, then the result is 0.
971 bool has_zero_operand = false;
972
973 for (int i = 0; i < 2; ++i) {
974 if (constants[i]) {
975 if (constants[i]->AsNullConstant() ||
976 constants[i]->AsVectorConstant()->IsZero()) {
977 has_zero_operand = true;
978 break;
979 }
980 }
981 }
982
983 if (has_zero_operand) {
984 if (float_type->width() == 32) {
985 utils::FloatProxy<float> result(0.0f);
986 std::vector<uint32_t> words = result.GetWords();
987 return const_mgr->GetConstant(float_type, words);
988 }
989 if (float_type->width() == 64) {
990 utils::FloatProxy<double> result(0.0);
991 std::vector<uint32_t> words = result.GetWords();
992 return const_mgr->GetConstant(float_type, words);
993 }
994 return nullptr;
995 }
996
997 if (constants[0] == nullptr || constants[1] == nullptr) {
998 return nullptr;
999 }
1000
1001 std::vector<const analysis::Constant*> a_components;
1002 std::vector<const analysis::Constant*> b_components;
1003
1004 a_components = constants[0]->GetVectorComponents(const_mgr);
1005 b_components = constants[1]->GetVectorComponents(const_mgr);
1006
1007 utils::FloatProxy<double> result(0.0);
1008 std::vector<uint32_t> words = result.GetWords();
1009 const analysis::Constant* result_const =
1010 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +01001011 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1012 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001013 if (a_components[i] == nullptr || b_components[i] == nullptr) {
1014 return nullptr;
1015 }
1016
1017 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1018 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +01001019 if (component == nullptr) {
1020 return nullptr;
1021 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001022 result_const =
1023 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1024 }
1025 return result_const;
1026 };
1027}
1028
1029// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
1030// from zero.
1031UnaryScalarFoldingRule FoldFNegateOp() {
1032 return [](const analysis::Type* result_type, const analysis::Constant* a,
1033 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1034 assert(result_type != nullptr && a != nullptr);
1035 assert(result_type == a->type());
Nicolas Capens84c9c452022-11-18 14:11:05 +00001036 return NegateFPConst(result_type, a, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -08001037 };
1038}
1039
1040ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
1041
Nicolas Capens84c9c452022-11-18 14:11:05 +00001042ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001043 return [cmp_opcode](IRContext* context, Instruction* inst,
1044 const std::vector<const analysis::Constant*>& constants)
1045 -> const analysis::Constant* {
1046 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1047 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1048
1049 if (!inst->IsFloatingPointFoldingAllowed()) {
1050 return nullptr;
1051 }
1052
1053 uint32_t non_const_idx = (constants[0] ? 1 : 0);
1054 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1055 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1056
1057 analysis::TypeManager* type_mgr = context->get_type_mgr();
1058 const analysis::Type* operand_type =
1059 type_mgr->GetType(operand_inst->type_id());
1060
1061 if (!operand_type->AsFloat()) {
1062 return nullptr;
1063 }
1064
1065 if (operand_type->AsFloat()->width() != 32 &&
1066 operand_type->AsFloat()->width() != 64) {
1067 return nullptr;
1068 }
1069
Nicolas Capens84c9c452022-11-18 14:11:05 +00001070 if (operand_inst->opcode() != spv::Op::OpExtInst) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001071 return nullptr;
1072 }
1073
1074 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1075 return nullptr;
1076 }
1077
1078 if (constants[1] == nullptr && constants[0] == nullptr) {
1079 return nullptr;
1080 }
1081
1082 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1083 const analysis::Constant* max_const =
1084 const_mgr->FindDeclaredConstant(max_id);
1085
1086 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1087 const analysis::Constant* min_const =
1088 const_mgr->FindDeclaredConstant(min_id);
1089
1090 bool found_result = false;
1091 bool result = false;
1092
1093 switch (cmp_opcode) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001094 case spv::Op::OpFOrdLessThan:
1095 case spv::Op::OpFUnordLessThan:
1096 case spv::Op::OpFOrdGreaterThanEqual:
1097 case spv::Op::OpFUnordGreaterThanEqual:
Chris Forbescc5697f2019-01-30 11:54:08 -08001098 if (constants[0]) {
1099 if (min_const) {
1100 if (constants[0]->GetValueAsDouble() <
1101 min_const->GetValueAsDouble()) {
1102 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001103 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1104 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001105 }
1106 }
1107 if (max_const) {
1108 if (constants[0]->GetValueAsDouble() >=
1109 max_const->GetValueAsDouble()) {
1110 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001111 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1112 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001113 }
1114 }
1115 }
1116
1117 if (constants[1]) {
1118 if (max_const) {
1119 if (max_const->GetValueAsDouble() <
1120 constants[1]->GetValueAsDouble()) {
1121 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001122 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1123 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001124 }
1125 }
1126
1127 if (min_const) {
1128 if (min_const->GetValueAsDouble() >=
1129 constants[1]->GetValueAsDouble()) {
1130 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001131 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1132 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001133 }
1134 }
1135 }
1136 break;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001137 case spv::Op::OpFOrdGreaterThan:
1138 case spv::Op::OpFUnordGreaterThan:
1139 case spv::Op::OpFOrdLessThanEqual:
1140 case spv::Op::OpFUnordLessThanEqual:
Chris Forbescc5697f2019-01-30 11:54:08 -08001141 if (constants[0]) {
1142 if (min_const) {
1143 if (constants[0]->GetValueAsDouble() <=
1144 min_const->GetValueAsDouble()) {
1145 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001146 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1147 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001148 }
1149 }
1150 if (max_const) {
1151 if (constants[0]->GetValueAsDouble() >
1152 max_const->GetValueAsDouble()) {
1153 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001154 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1155 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001156 }
1157 }
1158 }
1159
1160 if (constants[1]) {
1161 if (max_const) {
1162 if (max_const->GetValueAsDouble() <=
1163 constants[1]->GetValueAsDouble()) {
1164 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001165 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1166 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001167 }
1168 }
1169
1170 if (min_const) {
1171 if (min_const->GetValueAsDouble() >
1172 constants[1]->GetValueAsDouble()) {
1173 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001174 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1175 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001176 }
1177 }
1178 }
1179 break;
1180 default:
1181 return nullptr;
1182 }
1183
1184 if (!found_result) {
1185 return nullptr;
1186 }
1187
1188 const analysis::Type* bool_type =
1189 context->get_type_mgr()->GetType(inst->type_id());
1190 const analysis::Constant* result_const =
1191 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1192 assert(result_const);
1193 return result_const;
1194 };
1195}
1196
Ben Claytond0f684e2019-08-30 22:36:08 +01001197ConstantFoldingRule FoldFMix() {
1198 return [](IRContext* context, Instruction* inst,
1199 const std::vector<const analysis::Constant*>& constants)
1200 -> const analysis::Constant* {
1201 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
Nicolas Capens84c9c452022-11-18 14:11:05 +00001202 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond0f684e2019-08-30 22:36:08 +01001203 "Expecting an extended instruction.");
1204 assert(inst->GetSingleWordInOperand(0) ==
1205 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1206 "Expecting a GLSLstd450 extended instruction.");
1207 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1208 "Expecting and FMix instruction.");
1209
1210 if (!inst->IsFloatingPointFoldingAllowed()) {
1211 return nullptr;
1212 }
1213
1214 // Make sure all FMix operands are constants.
1215 for (uint32_t i = 1; i < 4; i++) {
1216 if (constants[i] == nullptr) {
1217 return nullptr;
1218 }
1219 }
1220
1221 const analysis::Constant* one;
Ben Claytond552f632019-11-18 11:18:41 +00001222 bool is_vector = false;
1223 const analysis::Type* result_type = constants[1]->type();
1224 const analysis::Type* base_type = result_type;
1225 if (base_type->AsVector()) {
1226 is_vector = true;
1227 base_type = base_type->AsVector()->element_type();
1228 }
1229 assert(base_type->AsFloat() != nullptr &&
1230 "FMix is suppose to act on floats or vectors of floats.");
1231
1232 if (base_type->AsFloat()->width() == 32) {
1233 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +01001234 utils::FloatProxy<float>(1.0f).GetWords());
1235 } else {
Ben Claytond552f632019-11-18 11:18:41 +00001236 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +01001237 utils::FloatProxy<double>(1.0).GetWords());
1238 }
1239
Ben Claytond552f632019-11-18 11:18:41 +00001240 if (is_vector) {
1241 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
1242 one =
1243 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1244 }
1245
1246 const analysis::Constant* temp1 = FoldFPBinaryOp(
1247 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001248 if (temp1 == nullptr) {
1249 return nullptr;
1250 }
1251
Ben Claytond552f632019-11-18 11:18:41 +00001252 const analysis::Constant* temp2 = FoldFPBinaryOp(
1253 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001254 if (temp2 == nullptr) {
1255 return nullptr;
1256 }
Ben Claytond552f632019-11-18 11:18:41 +00001257 const analysis::Constant* temp3 =
1258 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1259 {constants[2], constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001260 if (temp3 == nullptr) {
1261 return nullptr;
1262 }
Ben Claytond552f632019-11-18 11:18:41 +00001263 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1264 context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001265 };
1266}
1267
Ben Claytond552f632019-11-18 11:18:41 +00001268const analysis::Constant* FoldMin(const analysis::Type* result_type,
1269 const analysis::Constant* a,
1270 const analysis::Constant* b,
1271 analysis::ConstantManager*) {
1272 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1273 if (int_type->width() == 32) {
1274 if (int_type->IsSigned()) {
1275 int32_t va = a->GetS32();
1276 int32_t vb = b->GetS32();
1277 return (va < vb ? a : b);
1278 } else {
1279 uint32_t va = a->GetU32();
1280 uint32_t vb = b->GetU32();
1281 return (va < vb ? a : b);
1282 }
1283 } else if (int_type->width() == 64) {
1284 if (int_type->IsSigned()) {
1285 int64_t va = a->GetS64();
1286 int64_t vb = b->GetS64();
1287 return (va < vb ? a : b);
1288 } else {
1289 uint64_t va = a->GetU64();
1290 uint64_t vb = b->GetU64();
1291 return (va < vb ? a : b);
1292 }
1293 }
1294 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1295 if (float_type->width() == 32) {
1296 float va = a->GetFloat();
1297 float vb = b->GetFloat();
1298 return (va < vb ? a : b);
1299 } else if (float_type->width() == 64) {
1300 double va = a->GetDouble();
1301 double vb = b->GetDouble();
1302 return (va < vb ? a : b);
1303 }
1304 }
1305 return nullptr;
1306}
1307
1308const analysis::Constant* FoldMax(const analysis::Type* result_type,
1309 const analysis::Constant* a,
1310 const analysis::Constant* b,
1311 analysis::ConstantManager*) {
1312 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1313 if (int_type->width() == 32) {
1314 if (int_type->IsSigned()) {
1315 int32_t va = a->GetS32();
1316 int32_t vb = b->GetS32();
1317 return (va > vb ? a : b);
1318 } else {
1319 uint32_t va = a->GetU32();
1320 uint32_t vb = b->GetU32();
1321 return (va > vb ? a : b);
1322 }
1323 } else if (int_type->width() == 64) {
1324 if (int_type->IsSigned()) {
1325 int64_t va = a->GetS64();
1326 int64_t vb = b->GetS64();
1327 return (va > vb ? a : b);
1328 } else {
1329 uint64_t va = a->GetU64();
1330 uint64_t vb = b->GetU64();
1331 return (va > vb ? a : b);
1332 }
1333 }
1334 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1335 if (float_type->width() == 32) {
1336 float va = a->GetFloat();
1337 float vb = b->GetFloat();
1338 return (va > vb ? a : b);
1339 } else if (float_type->width() == 64) {
1340 double va = a->GetDouble();
1341 double vb = b->GetDouble();
1342 return (va > vb ? a : b);
1343 }
1344 }
1345 return nullptr;
1346}
1347
1348// Fold an clamp instruction when all three operands are constant.
1349const analysis::Constant* FoldClamp1(
1350 IRContext* context, Instruction* inst,
1351 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001352 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001353 "Expecting an extended instruction.");
1354 assert(inst->GetSingleWordInOperand(0) ==
1355 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1356 "Expecting a GLSLstd450 extended instruction.");
1357
1358 // Make sure all Clamp operands are constants.
Alexis Hetu00e0af12021-11-08 08:57:46 -05001359 for (uint32_t i = 1; i < 4; i++) {
Ben Claytond552f632019-11-18 11:18:41 +00001360 if (constants[i] == nullptr) {
1361 return nullptr;
1362 }
1363 }
1364
1365 const analysis::Constant* temp = FoldFPBinaryOp(
1366 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1367 if (temp == nullptr) {
1368 return nullptr;
1369 }
1370 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1371 context);
1372}
1373
Alexis Hetu00e0af12021-11-08 08:57:46 -05001374// Fold a clamp instruction when |x <= min_val|.
Ben Claytond552f632019-11-18 11:18:41 +00001375const analysis::Constant* FoldClamp2(
1376 IRContext* context, Instruction* inst,
1377 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001378 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001379 "Expecting an extended instruction.");
1380 assert(inst->GetSingleWordInOperand(0) ==
1381 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1382 "Expecting a GLSLstd450 extended instruction.");
1383
1384 const analysis::Constant* x = constants[1];
1385 const analysis::Constant* min_val = constants[2];
1386
1387 if (x == nullptr || min_val == nullptr) {
1388 return nullptr;
1389 }
1390
1391 const analysis::Constant* temp =
1392 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1393 if (temp == min_val) {
1394 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1395 // result of the max operation is |min_val|, we know the result of the min
1396 // operation, even if |max_val| is not a constant.
1397 return min_val;
1398 }
1399 return nullptr;
1400}
1401
1402// Fold a clamp instruction when |x >= max_val|.
1403const analysis::Constant* FoldClamp3(
1404 IRContext* context, Instruction* inst,
1405 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001406 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001407 "Expecting an extended instruction.");
1408 assert(inst->GetSingleWordInOperand(0) ==
1409 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1410 "Expecting a GLSLstd450 extended instruction.");
1411
1412 const analysis::Constant* x = constants[1];
1413 const analysis::Constant* max_val = constants[3];
1414
1415 if (x == nullptr || max_val == nullptr) {
1416 return nullptr;
1417 }
1418
1419 const analysis::Constant* temp =
1420 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1421 if (temp == max_val) {
1422 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1423 // result of the max operation is |min_val|, we know the result of the min
1424 // operation, even if |max_val| is not a constant.
1425 return max_val;
1426 }
1427 return nullptr;
1428}
1429
Ben Claytondc6b76a2020-02-24 14:53:40 +00001430UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1431 return
1432 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1433 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1434 assert(result_type != nullptr && a != nullptr);
1435 const analysis::Float* float_type = a->type()->AsFloat();
1436 assert(float_type != nullptr);
1437 assert(float_type == result_type->AsFloat());
1438 if (float_type->width() == 32) {
1439 float fa = a->GetFloat();
1440 float res = static_cast<float>(fp(fa));
1441 utils::FloatProxy<float> result(res);
1442 std::vector<uint32_t> words = result.GetWords();
1443 return const_mgr->GetConstant(result_type, words);
1444 } else if (float_type->width() == 64) {
1445 double fa = a->GetDouble();
1446 double res = fp(fa);
1447 utils::FloatProxy<double> result(res);
1448 std::vector<uint32_t> words = result.GetWords();
1449 return const_mgr->GetConstant(result_type, words);
1450 }
1451 return nullptr;
1452 };
1453}
1454
1455BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1456 double)) {
1457 return
1458 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1459 const analysis::Constant* b,
1460 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1461 assert(result_type != nullptr && a != nullptr);
1462 const analysis::Float* float_type = a->type()->AsFloat();
1463 assert(float_type != nullptr);
1464 assert(float_type == result_type->AsFloat());
1465 assert(float_type == b->type()->AsFloat());
1466 if (float_type->width() == 32) {
1467 float fa = a->GetFloat();
1468 float fb = b->GetFloat();
1469 float res = static_cast<float>(fp(fa, fb));
1470 utils::FloatProxy<float> result(res);
1471 std::vector<uint32_t> words = result.GetWords();
1472 return const_mgr->GetConstant(result_type, words);
1473 } else if (float_type->width() == 64) {
1474 double fa = a->GetDouble();
1475 double fb = b->GetDouble();
1476 double res = fp(fa, fb);
1477 utils::FloatProxy<double> result(res);
1478 std::vector<uint32_t> words = result.GetWords();
1479 return const_mgr->GetConstant(result_type, words);
1480 }
1481 return nullptr;
1482 };
1483}
Chris Forbescc5697f2019-01-30 11:54:08 -08001484} // namespace
1485
Ben Claytond0f684e2019-08-30 22:36:08 +01001486void ConstantFoldingRules::AddFoldingRules() {
Chris Forbescc5697f2019-01-30 11:54:08 -08001487 // Add all folding rules to the list for the opcodes to which they apply.
1488 // Note that the order in which rules are added to the list matters. If a rule
1489 // applies to the instruction, the rest of the rules will not be attempted.
1490 // Take that into consideration.
1491
Nicolas Capens84c9c452022-11-18 14:11:05 +00001492 rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
Chris Forbescc5697f2019-01-30 11:54:08 -08001493
Nicolas Capens84c9c452022-11-18 14:11:05 +00001494 rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1495 rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
Chris Forbescc5697f2019-01-30 11:54:08 -08001496
Nicolas Capens84c9c452022-11-18 14:11:05 +00001497 rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1498 rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1499 rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1500 rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
Chris Forbescc5697f2019-01-30 11:54:08 -08001501
Nicolas Capens84c9c452022-11-18 14:11:05 +00001502 rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1503 rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1504 rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1505 rules_[spv::Op::OpFMul].push_back(FoldFMul());
1506 rules_[spv::Op::OpFSub].push_back(FoldFSub());
Chris Forbescc5697f2019-01-30 11:54:08 -08001507
Nicolas Capens84c9c452022-11-18 14:11:05 +00001508 rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001509
Nicolas Capens84c9c452022-11-18 14:11:05 +00001510 rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001511
Nicolas Capens84c9c452022-11-18 14:11:05 +00001512 rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001513
Nicolas Capens84c9c452022-11-18 14:11:05 +00001514 rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001515
Nicolas Capens84c9c452022-11-18 14:11:05 +00001516 rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1517 rules_[spv::Op::OpFOrdLessThan].push_back(
1518 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001519
Nicolas Capens84c9c452022-11-18 14:11:05 +00001520 rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1521 rules_[spv::Op::OpFUnordLessThan].push_back(
1522 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001523
Nicolas Capens84c9c452022-11-18 14:11:05 +00001524 rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1525 rules_[spv::Op::OpFOrdGreaterThan].push_back(
1526 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001527
Nicolas Capens84c9c452022-11-18 14:11:05 +00001528 rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1529 rules_[spv::Op::OpFUnordGreaterThan].push_back(
1530 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001531
Nicolas Capens84c9c452022-11-18 14:11:05 +00001532 rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1533 rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1534 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001535
Nicolas Capens84c9c452022-11-18 14:11:05 +00001536 rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1537 rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1538 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001539
Nicolas Capens84c9c452022-11-18 14:11:05 +00001540 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1541 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1542 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001543
Nicolas Capens84c9c452022-11-18 14:11:05 +00001544 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1545 FoldFUnordGreaterThanEqual());
1546 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1547 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001548
Nicolas Capens84c9c452022-11-18 14:11:05 +00001549 rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1550 rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1551 rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1552 rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
Chris Forbescc5697f2019-01-30 11:54:08 -08001553
Nicolas Capens84c9c452022-11-18 14:11:05 +00001554 rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1555 rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
Ben Claytond0f684e2019-08-30 22:36:08 +01001556
1557 // Add rules for GLSLstd450
1558 FeatureManager* feature_manager = context_->get_feature_mgr();
1559 uint32_t ext_inst_glslstd450_id =
1560 feature_manager->GetExtInstImportId_GLSLstd450();
1561 if (ext_inst_glslstd450_id != 0) {
1562 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
Ben Claytond552f632019-11-18 11:18:41 +00001563 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1564 FoldFPBinaryOp(FoldMin));
1565 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1566 FoldFPBinaryOp(FoldMin));
1567 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1568 FoldFPBinaryOp(FoldMin));
1569 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1570 FoldFPBinaryOp(FoldMax));
1571 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1572 FoldFPBinaryOp(FoldMax));
1573 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1574 FoldFPBinaryOp(FoldMax));
1575 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1576 FoldClamp1);
1577 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1578 FoldClamp2);
1579 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1580 FoldClamp3);
1581 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1582 FoldClamp1);
1583 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1584 FoldClamp2);
1585 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1586 FoldClamp3);
1587 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1588 FoldClamp1);
1589 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1590 FoldClamp2);
1591 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1592 FoldClamp3);
Ben Claytondc6b76a2020-02-24 14:53:40 +00001593 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1594 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1595 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1596 FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1597 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1598 FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1599 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1600 FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1601 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1602 FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1603 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1604 FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1605 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1606 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1607 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1608 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1609
1610#ifdef __ANDROID__
sugoi1b398bf32022-02-18 10:27:28 -05001611 // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
Ben Claytondc6b76a2020-02-24 14:53:40 +00001612 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1613 // available up until ABI 18 so we use a shim
1614 auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1615 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1616 FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1617 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1618 FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1619#else
1620 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1621 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1622 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1623 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1624#endif
1625
1626 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1627 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1628 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1629 FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1630 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1631 FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
Ben Claytond0f684e2019-08-30 22:36:08 +01001632 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001633}
1634} // namespace opt
1635} // namespace spvtools