blob: e91455ecdef2f55f3b7adf4943aec379e5cbd170 [file] [log] [blame]
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
// In-operand index of the composite read by OpCompositeExtract; the literal
// member indexes follow it, starting at in-operand 1.
constexpr uint32_t kExtractCompositeIdInIdx{0};
Chris Forbescc5697f2019-01-30 11:54:08 -080023
Nicolas Capens6cacf182021-11-30 11:15:46 -050024// Returns a constants with the value NaN of the given type. Only works for
25// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
26const analysis::Constant* GetNan(const analysis::Type* type,
27 analysis::ConstantManager* const_mgr) {
28 const analysis::Float* float_type = type->AsFloat();
29 if (float_type == nullptr) {
30 return nullptr;
31 }
32
33 switch (float_type->width()) {
34 case 32:
35 return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
36 case 64:
37 return const_mgr->GetDoubleConst(
38 std::numeric_limits<double>::quiet_NaN());
39 default:
40 return nullptr;
41 }
42}
43
44// Returns a constants with the value INF of the given type. Only works for
45// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
46const analysis::Constant* GetInf(const analysis::Type* type,
47 analysis::ConstantManager* const_mgr) {
48 const analysis::Float* float_type = type->AsFloat();
49 if (float_type == nullptr) {
50 return nullptr;
51 }
52
53 switch (float_type->width()) {
54 case 32:
55 return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
56 case 64:
57 return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
58 default:
59 return nullptr;
60 }
61}
62
Chris Forbescc5697f2019-01-30 11:54:08 -080063// Returns true if |type| is Float or a vector of Float.
64bool HasFloatingPoint(const analysis::Type* type) {
65 if (type->AsFloat()) {
66 return true;
67 } else if (const analysis::Vector* vec_type = type->AsVector()) {
68 return vec_type->element_type()->AsFloat() != nullptr;
69 }
70
71 return false;
72}
73
Nicolas Capens6cacf182021-11-30 11:15:46 -050074// Returns a constants with the value |-val| of the given type. Only works for
75// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
Nicolas Capens84c9c452022-11-18 14:11:05 +000076const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
Nicolas Capens6cacf182021-11-30 11:15:46 -050077 const analysis::Constant* val,
78 analysis::ConstantManager* const_mgr) {
79 const analysis::Float* float_type = result_type->AsFloat();
80 assert(float_type != nullptr);
81 if (float_type->width() == 32) {
82 float fa = val->GetFloat();
83 return const_mgr->GetFloatConst(-fa);
84 } else if (float_type->width() == 64) {
85 double da = val->GetDouble();
86 return const_mgr->GetDoubleConst(-da);
87 }
88 return nullptr;
89}
90
Chris Forbescc5697f2019-01-30 11:54:08 -080091// Folds an OpcompositeExtract where input is a composite constant.
92ConstantFoldingRule FoldExtractWithConstants() {
93 return [](IRContext* context, Instruction* inst,
94 const std::vector<const analysis::Constant*>& constants)
95 -> const analysis::Constant* {
96 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
97 if (c == nullptr) {
98 return nullptr;
99 }
100
101 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
102 uint32_t element_index = inst->GetSingleWordInOperand(i);
103 if (c->AsNullConstant()) {
104 // Return Null for the return type.
105 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
106 analysis::TypeManager* type_mgr = context->get_type_mgr();
107 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
108 }
109
110 auto cc = c->AsCompositeConstant();
111 assert(cc != nullptr);
112 auto components = cc->GetComponents();
Ben Claytond0f684e2019-08-30 22:36:08 +0100113 // Protect against invalid IR. Refuse to fold if the index is out
114 // of bounds.
115 if (element_index >= components.size()) return nullptr;
Chris Forbescc5697f2019-01-30 11:54:08 -0800116 c = components[element_index];
117 }
118 return c;
119 };
120}
121
Nicolas Capens84c9c452022-11-18 14:11:05 +0000122// Folds an OpcompositeInsert where input is a composite constant.
123ConstantFoldingRule FoldInsertWithConstants() {
124 return [](IRContext* context, Instruction* inst,
125 const std::vector<const analysis::Constant*>& constants)
126 -> const analysis::Constant* {
127 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
128 const analysis::Constant* object = constants[0];
129 const analysis::Constant* composite = constants[1];
130 if (object == nullptr || composite == nullptr) {
131 return nullptr;
132 }
133
134 // If there is more than 1 index, then each additional constant used by the
135 // index will need to be recreated to use the inserted object.
136 std::vector<const analysis::Constant*> chain;
137 std::vector<const analysis::Constant*> components;
138 const analysis::Type* type = nullptr;
139
140 // Work down hierarchy and add all the indexes, not including the final
141 // index.
142 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
143 if (i != inst->NumInOperands() - 1) {
144 chain.push_back(composite);
145 }
146 const uint32_t index = inst->GetSingleWordInOperand(i);
147 components = composite->AsCompositeConstant()->GetComponents();
148 type = composite->AsCompositeConstant()->type();
149 composite = components[index];
150 }
151
152 // Final index in hierarchy is inserted with new object.
153 const uint32_t final_index =
154 inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
155 std::vector<uint32_t> ids;
156 for (size_t i = 0; i < components.size(); i++) {
157 const analysis::Constant* constant =
158 (i == final_index) ? object : components[i];
159 Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
160 ids.push_back(member_inst->result_id());
161 }
162 const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
163
164 // Work backwards up the chain and replace each index with new constant.
165 for (size_t i = chain.size(); i > 0; i--) {
166 // Need to insert any previous instruction into the module first.
167 // Can't just insert in types_values_begin() because it will move above
168 // where the types are declared
169 for (Module::inst_iterator inst_iter = context->types_values_begin();
170 inst_iter != context->types_values_end(); ++inst_iter) {
171 Instruction* x = &*inst_iter;
172 if (inst->result_id() == x->result_id()) {
173 const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
174 break;
175 }
176 }
177
178 composite = chain[i - 1];
179 components = composite->AsCompositeConstant()->GetComponents();
180 type = composite->AsCompositeConstant()->type();
181 ids.clear();
182 for (size_t k = 0; k < components.size(); k++) {
183 const uint32_t index =
184 inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
185 const analysis::Constant* constant =
186 (k == index) ? new_constant : components[k];
187 const uint32_t constant_id =
188 const_mgr->FindDeclaredConstant(constant, 0);
189 ids.push_back(constant_id);
190 }
191 new_constant = const_mgr->GetConstant(type, ids);
192 }
193
194 // If multiple constants were created, only need to return the top index.
195 return new_constant;
196 };
197}
198
Chris Forbescc5697f2019-01-30 11:54:08 -0800199ConstantFoldingRule FoldVectorShuffleWithConstants() {
200 return [](IRContext* context, Instruction* inst,
201 const std::vector<const analysis::Constant*>& constants)
202 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000203 assert(inst->opcode() == spv::Op::OpVectorShuffle);
Chris Forbescc5697f2019-01-30 11:54:08 -0800204 const analysis::Constant* c1 = constants[0];
205 const analysis::Constant* c2 = constants[1];
206 if (c1 == nullptr || c2 == nullptr) {
207 return nullptr;
208 }
209
210 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
211 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
212
213 std::vector<const analysis::Constant*> c1_components;
214 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
215 c1_components = vec_const->GetComponents();
216 } else {
217 assert(c1->AsNullConstant());
218 const analysis::Constant* element =
219 const_mgr->GetConstant(element_type, {});
220 c1_components.resize(c1->type()->AsVector()->element_count(), element);
221 }
222 std::vector<const analysis::Constant*> c2_components;
223 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
224 c2_components = vec_const->GetComponents();
225 } else {
226 assert(c2->AsNullConstant());
227 const analysis::Constant* element =
228 const_mgr->GetConstant(element_type, {});
229 c2_components.resize(c2->type()->AsVector()->element_count(), element);
230 }
231
232 std::vector<uint32_t> ids;
233 const uint32_t undef_literal_value = 0xffffffff;
234 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
235 uint32_t index = inst->GetSingleWordInOperand(i);
236 if (index == undef_literal_value) {
237 // Don't fold shuffle with undef literal value.
238 return nullptr;
239 } else if (index < c1_components.size()) {
240 Instruction* member_inst =
241 const_mgr->GetDefiningInstruction(c1_components[index]);
242 ids.push_back(member_inst->result_id());
243 } else {
244 Instruction* member_inst = const_mgr->GetDefiningInstruction(
245 c2_components[index - c1_components.size()]);
246 ids.push_back(member_inst->result_id());
247 }
248 }
249
250 analysis::TypeManager* type_mgr = context->get_type_mgr();
251 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
252 };
253}
254
255ConstantFoldingRule FoldVectorTimesScalar() {
256 return [](IRContext* context, Instruction* inst,
257 const std::vector<const analysis::Constant*>& constants)
258 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000259 assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
Chris Forbescc5697f2019-01-30 11:54:08 -0800260 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
261 analysis::TypeManager* type_mgr = context->get_type_mgr();
262
263 if (!inst->IsFloatingPointFoldingAllowed()) {
264 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
265 return nullptr;
266 }
267 }
268
269 const analysis::Constant* c1 = constants[0];
270 const analysis::Constant* c2 = constants[1];
271
272 if (c1 && c1->IsZero()) {
273 return c1;
274 }
275
276 if (c2 && c2->IsZero()) {
277 // Get or create the NullConstant for this type.
278 std::vector<uint32_t> ids;
279 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
280 }
281
282 if (c1 == nullptr || c2 == nullptr) {
283 return nullptr;
284 }
285
286 // Check result type.
287 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
288 const analysis::Vector* vector_type = result_type->AsVector();
289 assert(vector_type != nullptr);
290 const analysis::Type* element_type = vector_type->element_type();
291 assert(element_type != nullptr);
292 const analysis::Float* float_type = element_type->AsFloat();
293 assert(float_type != nullptr);
294
295 // Check types of c1 and c2.
296 assert(c1->type()->AsVector() == vector_type);
297 assert(c1->type()->AsVector()->element_type() == element_type &&
298 c2->type() == element_type);
299
300 // Get a float vector that is the result of vector-times-scalar.
301 std::vector<const analysis::Constant*> c1_components =
302 c1->GetVectorComponents(const_mgr);
303 std::vector<uint32_t> ids;
304 if (float_type->width() == 32) {
305 float scalar = c2->GetFloat();
306 for (uint32_t i = 0; i < c1_components.size(); ++i) {
307 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
308 std::vector<uint32_t> words = result.GetWords();
309 const analysis::Constant* new_elem =
310 const_mgr->GetConstant(float_type, words);
311 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
312 }
313 return const_mgr->GetConstant(vector_type, ids);
314 } else if (float_type->width() == 64) {
315 double scalar = c2->GetDouble();
316 for (uint32_t i = 0; i < c1_components.size(); ++i) {
317 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
318 scalar);
319 std::vector<uint32_t> words = result.GetWords();
320 const analysis::Constant* new_elem =
321 const_mgr->GetConstant(float_type, words);
322 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
323 }
324 return const_mgr->GetConstant(vector_type, ids);
325 }
326 return nullptr;
327 };
328}
329
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400330ConstantFoldingRule FoldVectorTimesMatrix() {
331 return [](IRContext* context, Instruction* inst,
332 const std::vector<const analysis::Constant*>& constants)
333 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000334 assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400335 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
336 analysis::TypeManager* type_mgr = context->get_type_mgr();
337
338 if (!inst->IsFloatingPointFoldingAllowed()) {
339 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
340 return nullptr;
341 }
342 }
343
344 const analysis::Constant* c1 = constants[0];
345 const analysis::Constant* c2 = constants[1];
346
347 if (c1 == nullptr || c2 == nullptr) {
348 return nullptr;
349 }
350
351 // Check result type.
352 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
353 const analysis::Vector* vector_type = result_type->AsVector();
354 assert(vector_type != nullptr);
355 const analysis::Type* element_type = vector_type->element_type();
356 assert(element_type != nullptr);
357 const analysis::Float* float_type = element_type->AsFloat();
358 assert(float_type != nullptr);
359
360 // Check types of c1 and c2.
361 assert(c1->type()->AsVector() == vector_type);
362 assert(c1->type()->AsVector()->element_type() == element_type &&
363 c2->type()->AsMatrix()->element_type() == vector_type);
364
365 // Get a float vector that is the result of vector-times-matrix.
366 std::vector<const analysis::Constant*> c1_components =
367 c1->GetVectorComponents(const_mgr);
368 std::vector<const analysis::Constant*> c2_components =
369 c2->AsMatrixConstant()->GetComponents();
370 uint32_t resultVectorSize = result_type->AsVector()->element_count();
371
372 std::vector<uint32_t> ids;
373
374 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
375 std::vector<uint32_t> words(float_type->width() / 32, 0);
376 for (uint32_t i = 0; i < resultVectorSize; ++i) {
377 const analysis::Constant* new_elem =
378 const_mgr->GetConstant(float_type, words);
379 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
380 }
381 return const_mgr->GetConstant(vector_type, ids);
382 }
383
384 if (float_type->width() == 32) {
385 for (uint32_t i = 0; i < resultVectorSize; ++i) {
386 float result_scalar = 0.0f;
387 const analysis::VectorConstant* c2_vec =
388 c2_components[i]->AsVectorConstant();
389 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
390 float c1_scalar = c1_components[j]->GetFloat();
391 float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
392 result_scalar += c1_scalar * c2_scalar;
393 }
394 utils::FloatProxy<float> result(result_scalar);
395 std::vector<uint32_t> words = result.GetWords();
396 const analysis::Constant* new_elem =
397 const_mgr->GetConstant(float_type, words);
398 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
399 }
400 return const_mgr->GetConstant(vector_type, ids);
401 } else if (float_type->width() == 64) {
402 for (uint32_t i = 0; i < c2_components.size(); ++i) {
403 double result_scalar = 0.0;
404 const analysis::VectorConstant* c2_vec =
405 c2_components[i]->AsVectorConstant();
406 for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
407 double c1_scalar = c1_components[j]->GetDouble();
408 double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
409 result_scalar += c1_scalar * c2_scalar;
410 }
411 utils::FloatProxy<double> result(result_scalar);
412 std::vector<uint32_t> words = result.GetWords();
413 const analysis::Constant* new_elem =
414 const_mgr->GetConstant(float_type, words);
415 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
416 }
417 return const_mgr->GetConstant(vector_type, ids);
418 }
419 return nullptr;
420 };
421}
422
423ConstantFoldingRule FoldMatrixTimesVector() {
424 return [](IRContext* context, Instruction* inst,
425 const std::vector<const analysis::Constant*>& constants)
426 -> const analysis::Constant* {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000427 assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
Nicolas Capens00a1bcc2022-07-29 16:49:40 -0400428 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
429 analysis::TypeManager* type_mgr = context->get_type_mgr();
430
431 if (!inst->IsFloatingPointFoldingAllowed()) {
432 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
433 return nullptr;
434 }
435 }
436
437 const analysis::Constant* c1 = constants[0];
438 const analysis::Constant* c2 = constants[1];
439
440 if (c1 == nullptr || c2 == nullptr) {
441 return nullptr;
442 }
443
444 // Check result type.
445 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
446 const analysis::Vector* vector_type = result_type->AsVector();
447 assert(vector_type != nullptr);
448 const analysis::Type* element_type = vector_type->element_type();
449 assert(element_type != nullptr);
450 const analysis::Float* float_type = element_type->AsFloat();
451 assert(float_type != nullptr);
452
453 // Check types of c1 and c2.
454 assert(c1->type()->AsMatrix()->element_type() == vector_type);
455 assert(c2->type()->AsVector()->element_type() == element_type);
456
457 // Get a float vector that is the result of matrix-times-vector.
458 std::vector<const analysis::Constant*> c1_components =
459 c1->AsMatrixConstant()->GetComponents();
460 std::vector<const analysis::Constant*> c2_components =
461 c2->GetVectorComponents(const_mgr);
462 uint32_t resultVectorSize = result_type->AsVector()->element_count();
463
464 std::vector<uint32_t> ids;
465
466 if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
467 std::vector<uint32_t> words(float_type->width() / 32, 0);
468 for (uint32_t i = 0; i < resultVectorSize; ++i) {
469 const analysis::Constant* new_elem =
470 const_mgr->GetConstant(float_type, words);
471 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
472 }
473 return const_mgr->GetConstant(vector_type, ids);
474 }
475
476 if (float_type->width() == 32) {
477 for (uint32_t i = 0; i < resultVectorSize; ++i) {
478 float result_scalar = 0.0f;
479 for (uint32_t j = 0; j < c1_components.size(); ++j) {
480 float c1_scalar = c1_components[j]
481 ->AsVectorConstant()
482 ->GetComponents()[i]
483 ->GetFloat();
484 float c2_scalar = c2_components[j]->GetFloat();
485 result_scalar += c1_scalar * c2_scalar;
486 }
487 utils::FloatProxy<float> result(result_scalar);
488 std::vector<uint32_t> words = result.GetWords();
489 const analysis::Constant* new_elem =
490 const_mgr->GetConstant(float_type, words);
491 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
492 }
493 return const_mgr->GetConstant(vector_type, ids);
494 } else if (float_type->width() == 64) {
495 for (uint32_t i = 0; i < resultVectorSize; ++i) {
496 double result_scalar = 0.0;
497 for (uint32_t j = 0; j < c1_components.size(); ++j) {
498 double c1_scalar = c1_components[j]
499 ->AsVectorConstant()
500 ->GetComponents()[i]
501 ->GetDouble();
502 double c2_scalar = c2_components[j]->GetDouble();
503 result_scalar += c1_scalar * c2_scalar;
504 }
505 utils::FloatProxy<double> result(result_scalar);
506 std::vector<uint32_t> words = result.GetWords();
507 const analysis::Constant* new_elem =
508 const_mgr->GetConstant(float_type, words);
509 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
510 }
511 return const_mgr->GetConstant(vector_type, ids);
512 }
513 return nullptr;
514 };
515}
516
Chris Forbescc5697f2019-01-30 11:54:08 -0800517ConstantFoldingRule FoldCompositeWithConstants() {
518 // Folds an OpCompositeConstruct where all of the inputs are constants to a
519 // constant. A new constant is created if necessary.
520 return [](IRContext* context, Instruction* inst,
521 const std::vector<const analysis::Constant*>& constants)
522 -> const analysis::Constant* {
523 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
524 analysis::TypeManager* type_mgr = context->get_type_mgr();
525 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
526 Instruction* type_inst =
527 context->get_def_use_mgr()->GetDef(inst->type_id());
528
529 std::vector<uint32_t> ids;
530 for (uint32_t i = 0; i < constants.size(); ++i) {
531 const analysis::Constant* element_const = constants[i];
532 if (element_const == nullptr) {
533 return nullptr;
534 }
535
536 uint32_t component_type_id = 0;
Nicolas Capens84c9c452022-11-18 14:11:05 +0000537 if (type_inst->opcode() == spv::Op::OpTypeStruct) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800538 component_type_id = type_inst->GetSingleWordInOperand(i);
Nicolas Capens84c9c452022-11-18 14:11:05 +0000539 } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800540 component_type_id = type_inst->GetSingleWordInOperand(0);
541 }
542
543 uint32_t element_id =
544 const_mgr->FindDeclaredConstant(element_const, component_type_id);
545 if (element_id == 0) {
546 return nullptr;
547 }
548 ids.push_back(element_id);
549 }
550 return const_mgr->GetConstant(new_type, ids);
551 };
552}
553
554// The interface for a function that returns the result of applying a scalar
555// floating-point binary operation on |a| and |b|. The type of the return value
556// will be |type|. The input constants must also be of type |type|.
557using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
558 const analysis::Type* result_type, const analysis::Constant* a,
559 analysis::ConstantManager*)>;
560
561// The interface for a function that returns the result of applying a scalar
562// floating-point binary operation on |a| and |b|. The type of the return value
563// will be |type|. The input constants must also be of type |type|.
564using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
565 const analysis::Type* result_type, const analysis::Constant* a,
566 const analysis::Constant* b, analysis::ConstantManager*)>;
567
568// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
569// using |scalar_rule| and unary float point vectors ops by applying
570// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
571// that is returned assumes that |constants| contains 1 entry. If they are
572// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
573// whose element type is |Float| or |Integer|.
574ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
575 return [scalar_rule](IRContext* context, Instruction* inst,
576 const std::vector<const analysis::Constant*>& constants)
577 -> const analysis::Constant* {
578 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
579 analysis::TypeManager* type_mgr = context->get_type_mgr();
580 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
581 const analysis::Vector* vector_type = result_type->AsVector();
582
583 if (!inst->IsFloatingPointFoldingAllowed()) {
584 return nullptr;
585 }
586
Ben Claytondc6b76a2020-02-24 14:53:40 +0000587 const analysis::Constant* arg =
Nicolas Capens84c9c452022-11-18 14:11:05 +0000588 (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
Ben Claytondc6b76a2020-02-24 14:53:40 +0000589
590 if (arg == nullptr) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800591 return nullptr;
592 }
593
594 if (vector_type != nullptr) {
595 std::vector<const analysis::Constant*> a_components;
596 std::vector<const analysis::Constant*> results_components;
597
Ben Claytondc6b76a2020-02-24 14:53:40 +0000598 a_components = arg->GetVectorComponents(const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800599
600 // Fold each component of the vector.
601 for (uint32_t i = 0; i < a_components.size(); ++i) {
602 results_components.push_back(scalar_rule(vector_type->element_type(),
603 a_components[i], const_mgr));
604 if (results_components[i] == nullptr) {
605 return nullptr;
606 }
607 }
608
609 // Build the constant object and return it.
610 std::vector<uint32_t> ids;
611 for (const analysis::Constant* member : results_components) {
612 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
613 }
614 return const_mgr->GetConstant(vector_type, ids);
615 } else {
Ben Claytondc6b76a2020-02-24 14:53:40 +0000616 return scalar_rule(result_type, arg, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800617 }
618 };
619}
620
Ben Claytond552f632019-11-18 11:18:41 +0000621// Returns the result of folding the constants in |constants| according the
622// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
623// per component.
624const analysis::Constant* FoldFPBinaryOp(
625 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
626 const std::vector<const analysis::Constant*>& constants,
627 IRContext* context) {
628 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
629 analysis::TypeManager* type_mgr = context->get_type_mgr();
630 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
631 const analysis::Vector* vector_type = result_type->AsVector();
632
633 if (constants[0] == nullptr || constants[1] == nullptr) {
634 return nullptr;
635 }
636
637 if (vector_type != nullptr) {
638 std::vector<const analysis::Constant*> a_components;
639 std::vector<const analysis::Constant*> b_components;
640 std::vector<const analysis::Constant*> results_components;
641
642 a_components = constants[0]->GetVectorComponents(const_mgr);
643 b_components = constants[1]->GetVectorComponents(const_mgr);
644
645 // Fold each component of the vector.
646 for (uint32_t i = 0; i < a_components.size(); ++i) {
647 results_components.push_back(scalar_rule(vector_type->element_type(),
648 a_components[i], b_components[i],
649 const_mgr));
650 if (results_components[i] == nullptr) {
651 return nullptr;
652 }
653 }
654
655 // Build the constant object and return it.
656 std::vector<uint32_t> ids;
657 for (const analysis::Constant* member : results_components) {
658 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
659 }
660 return const_mgr->GetConstant(vector_type, ids);
661 } else {
662 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
663 }
664}
665
Chris Forbescc5697f2019-01-30 11:54:08 -0800666// Returns a |ConstantFoldingRule| that folds floating point scalars using
667// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
668// elements of the vector. The |ConstantFoldingRule| that is returned assumes
669// that |constants| contains 2 entries. If they are not |nullptr|, then their
670// type is either |Float| or a |Vector| whose element type is |Float|.
671ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
672 return [scalar_rule](IRContext* context, Instruction* inst,
673 const std::vector<const analysis::Constant*>& constants)
674 -> const analysis::Constant* {
Chris Forbescc5697f2019-01-30 11:54:08 -0800675 if (!inst->IsFloatingPointFoldingAllowed()) {
676 return nullptr;
677 }
Nicolas Capens84c9c452022-11-18 14:11:05 +0000678 if (inst->opcode() == spv::Op::OpExtInst) {
Ben Claytond552f632019-11-18 11:18:41 +0000679 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
680 {constants[1], constants[2]}, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800681 }
Ben Claytond552f632019-11-18 11:18:41 +0000682 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800683 };
684}
685
686// This macro defines a |UnaryScalarFoldingRule| that performs float to
687// integer conversion.
688// TODO(greg-lunarg): Support for 64-bit integer types.
689UnaryScalarFoldingRule FoldFToIOp() {
690 return [](const analysis::Type* result_type, const analysis::Constant* a,
691 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
692 assert(result_type != nullptr && a != nullptr);
693 const analysis::Integer* integer_type = result_type->AsInteger();
694 const analysis::Float* float_type = a->type()->AsFloat();
695 assert(float_type != nullptr);
696 assert(integer_type != nullptr);
697 if (integer_type->width() != 32) return nullptr;
698 if (float_type->width() == 32) {
699 float fa = a->GetFloat();
700 uint32_t result = integer_type->IsSigned()
701 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
702 : static_cast<uint32_t>(fa);
703 std::vector<uint32_t> words = {result};
704 return const_mgr->GetConstant(result_type, words);
705 } else if (float_type->width() == 64) {
706 double fa = a->GetDouble();
707 uint32_t result = integer_type->IsSigned()
708 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
709 : static_cast<uint32_t>(fa);
710 std::vector<uint32_t> words = {result};
711 return const_mgr->GetConstant(result_type, words);
712 }
713 return nullptr;
714 };
715}
716
717// This function defines a |UnaryScalarFoldingRule| that performs integer to
718// float conversion.
719// TODO(greg-lunarg): Support for 64-bit integer types.
720UnaryScalarFoldingRule FoldIToFOp() {
721 return [](const analysis::Type* result_type, const analysis::Constant* a,
722 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
723 assert(result_type != nullptr && a != nullptr);
724 const analysis::Integer* integer_type = a->type()->AsInteger();
725 const analysis::Float* float_type = result_type->AsFloat();
726 assert(float_type != nullptr);
727 assert(integer_type != nullptr);
728 if (integer_type->width() != 32) return nullptr;
729 uint32_t ua = a->GetU32();
730 if (float_type->width() == 32) {
731 float result_val = integer_type->IsSigned()
732 ? static_cast<float>(static_cast<int32_t>(ua))
733 : static_cast<float>(ua);
734 utils::FloatProxy<float> result(result_val);
735 std::vector<uint32_t> words = {result.data()};
736 return const_mgr->GetConstant(result_type, words);
737 } else if (float_type->width() == 64) {
738 double result_val = integer_type->IsSigned()
739 ? static_cast<double>(static_cast<int32_t>(ua))
740 : static_cast<double>(ua);
741 utils::FloatProxy<double> result(result_val);
742 std::vector<uint32_t> words = result.GetWords();
743 return const_mgr->GetConstant(result_type, words);
744 }
745 return nullptr;
746 };
747}
748
Ben Claytonb73b7602019-07-29 13:56:13 +0100749// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
750UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
751 return [](const analysis::Type* result_type, const analysis::Constant* a,
752 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
753 assert(result_type != nullptr && a != nullptr);
754 const analysis::Float* float_type = a->type()->AsFloat();
755 assert(float_type != nullptr);
756 if (float_type->width() != 32) {
757 return nullptr;
758 }
759
760 float fa = a->GetFloat();
761 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
762 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
763 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
764 orignal.castTo(quantized, utils::round_direction::kToZero);
765 quantized.castTo(result, utils::round_direction::kToZero);
766 std::vector<uint32_t> words = {result.getBits()};
767 return const_mgr->GetConstant(result_type, words);
768 };
769}
770
// This macro expands to a |BinaryScalarFoldingRule| lambda that applies the
// binary operator |op| to two scalar float constants that share the result
// type. The operator is used as "f1 op f2", so it must work for both float and
// double. Only 32-bit and 64-bit floats are folded; any other width yields
// |nullptr|. The "_in_macro" suffixes presumably keep the lambda's names from
// shadowing names such as |result_type| or |const_mgr| at expansion sites.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    std::vector<uint32_t> words_in_macro;                                     \
    switch (float_type_in_macro->width()) {                                   \
      case 32:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<float>(a->GetFloat() op b->GetFloat())          \
                .GetWords();                                                  \
        break;                                                                \
      case 64:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<double>(a->GetDouble() op b->GetDouble())       \
                .GetWords();                                                  \
        break;                                                                \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
    return const_mgr_in_macro->GetConstant(result_type_in_macro,              \
                                           words_in_macro);                   \
  }
801
802// Define the folding rule for conversion between floating point and integer
803ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
804ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100805ConstantFoldingRule FoldQuantizeToF16() {
806 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
807}
Chris Forbescc5697f2019-01-30 11:54:08 -0800808
809// Define the folding rules for subtraction, addition, multiplication, and
810// division for floating point values.
811ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
812ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
813ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
Nicolas Capens6cacf182021-11-30 11:15:46 -0500814
815// Returns the constant that results from evaluating |numerator| / 0.0. Returns
sugoi1b398bf32022-02-18 10:27:28 -0500816// |nullptr| if the result could not be evaluated.
Nicolas Capens6cacf182021-11-30 11:15:46 -0500817const analysis::Constant* FoldFPScalarDivideByZero(
818 const analysis::Type* result_type, const analysis::Constant* numerator,
819 analysis::ConstantManager* const_mgr) {
820 if (numerator == nullptr) {
821 return nullptr;
822 }
823
824 if (numerator->IsZero()) {
825 return GetNan(result_type, const_mgr);
826 }
827
828 const analysis::Constant* result = GetInf(result_type, const_mgr);
829 if (result == nullptr) {
830 return nullptr;
831 }
832
833 if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
Nicolas Capens84c9c452022-11-18 14:11:05 +0000834 result = NegateFPConst(result_type, result, const_mgr);
Nicolas Capens6cacf182021-11-30 11:15:46 -0500835 }
836 return result;
837}
838
839// Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
840// if it cannot be folded.
841const analysis::Constant* FoldScalarFPDivide(
842 const analysis::Type* result_type, const analysis::Constant* numerator,
843 const analysis::Constant* denominator,
844 analysis::ConstantManager* const_mgr) {
845 if (denominator == nullptr) {
846 return nullptr;
847 }
848
849 if (denominator->IsZero()) {
850 return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
851 }
852
853 const analysis::FloatConstant* denominator_float =
854 denominator->AsFloatConstant();
855 if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
856 const analysis::Constant* result =
857 FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
858 if (result != nullptr)
Nicolas Capens84c9c452022-11-18 14:11:05 +0000859 result = NegateFPConst(result_type, result, const_mgr);
Nicolas Capens6cacf182021-11-30 11:15:46 -0500860 return result;
861 } else {
862 return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
863 }
864}
865
866// Returns the constant folding rule to fold |OpFDiv| with two constants.
867ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
Chris Forbescc5697f2019-01-30 11:54:08 -0800868
// Combines a raw floating-point comparison with NaN handling.
// |op_result| is the value of "f1 op f2", and |op_unordered| is true when
// either operand is NaN. When |need_ordered| is true, the result follows
// OpFOrd* semantics: both operands must be ordered and the relation must hold.
// Otherwise it follows OpFUnord* semantics: either operand is unordered or the
// relation holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
879
// This macro expands to a |BinaryScalarFoldingRule| lambda that compares two
// scalar float constants of the same 32-bit or 64-bit type with |op|. The
// operator is used as "f1 op f2", so it must work for both float and double.
// The lambda returns a boolean constant of the result type.
// |ord| is passed to |CompareFloatingPoint| as |need_ordered|:
//   - true gives OpFOrd* semantics (the result is false if either operand is
//     NaN);
//   - false gives OpFUnord* semantics (the result is true if either operand
//     is NaN).
// Any other float width yields |nullptr|.
#define FOLD_FPCMP_OP(op, ord)                                                \
  [](const analysis::Type* result_type, const analysis::Constant* a,          \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {     \
    assert(result_type != nullptr && a != nullptr && b != nullptr);           \
    assert(result_type->AsBool());                                            \
    assert(a->type() == b->type());                                           \
    const analysis::Float* float_type = a->type()->AsFloat();                 \
    assert(float_type != nullptr);                                            \
    bool cmp_result = false;                                                  \
    if (float_type->width() == 32) {                                          \
      const float lhs = a->GetFloat();                                        \
      const float rhs = b->GetFloat();                                        \
      cmp_result = CompareFloatingPoint(                                      \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);               \
    } else if (float_type->width() == 64) {                                   \
      const double lhs = a->GetDouble();                                      \
      const double rhs = b->GetDouble();                                      \
      cmp_result = CompareFloatingPoint(                                      \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);               \
    } else {                                                                  \
      return nullptr;                                                         \
    }                                                                         \
    std::vector<uint32_t> words = {static_cast<uint32_t>(cmp_result)};        \
    return const_mgr->GetConstant(result_type, words);                        \
  }
908
909// Define the folding rules for ordered and unordered comparison for floating
910// point values.
911ConstantFoldingRule FoldFOrdEqual() {
912 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
913}
914ConstantFoldingRule FoldFUnordEqual() {
915 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
916}
917ConstantFoldingRule FoldFOrdNotEqual() {
918 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
919}
920ConstantFoldingRule FoldFUnordNotEqual() {
921 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
922}
923ConstantFoldingRule FoldFOrdLessThan() {
924 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
925}
926ConstantFoldingRule FoldFUnordLessThan() {
927 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
928}
929ConstantFoldingRule FoldFOrdGreaterThan() {
930 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
931}
932ConstantFoldingRule FoldFUnordGreaterThan() {
933 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
934}
935ConstantFoldingRule FoldFOrdLessThanEqual() {
936 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
937}
938ConstantFoldingRule FoldFUnordLessThanEqual() {
939 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
940}
941ConstantFoldingRule FoldFOrdGreaterThanEqual() {
942 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
943}
944ConstantFoldingRule FoldFUnordGreaterThanEqual() {
945 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
946}
947
948// Folds an OpDot where all of the inputs are constants to a
949// constant. A new constant is created if necessary.
950ConstantFoldingRule FoldOpDotWithConstants() {
951 return [](IRContext* context, Instruction* inst,
952 const std::vector<const analysis::Constant*>& constants)
953 -> const analysis::Constant* {
954 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
955 analysis::TypeManager* type_mgr = context->get_type_mgr();
956 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
957 assert(new_type->AsFloat() && "OpDot should have a float return type.");
958 const analysis::Float* float_type = new_type->AsFloat();
959
960 if (!inst->IsFloatingPointFoldingAllowed()) {
961 return nullptr;
962 }
963
964 // If one of the operands is 0, then the result is 0.
965 bool has_zero_operand = false;
966
967 for (int i = 0; i < 2; ++i) {
968 if (constants[i]) {
969 if (constants[i]->AsNullConstant() ||
970 constants[i]->AsVectorConstant()->IsZero()) {
971 has_zero_operand = true;
972 break;
973 }
974 }
975 }
976
977 if (has_zero_operand) {
978 if (float_type->width() == 32) {
979 utils::FloatProxy<float> result(0.0f);
980 std::vector<uint32_t> words = result.GetWords();
981 return const_mgr->GetConstant(float_type, words);
982 }
983 if (float_type->width() == 64) {
984 utils::FloatProxy<double> result(0.0);
985 std::vector<uint32_t> words = result.GetWords();
986 return const_mgr->GetConstant(float_type, words);
987 }
988 return nullptr;
989 }
990
991 if (constants[0] == nullptr || constants[1] == nullptr) {
992 return nullptr;
993 }
994
995 std::vector<const analysis::Constant*> a_components;
996 std::vector<const analysis::Constant*> b_components;
997
998 a_components = constants[0]->GetVectorComponents(const_mgr);
999 b_components = constants[1]->GetVectorComponents(const_mgr);
1000
1001 utils::FloatProxy<double> result(0.0);
1002 std::vector<uint32_t> words = result.GetWords();
1003 const analysis::Constant* result_const =
1004 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +01001005 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1006 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001007 if (a_components[i] == nullptr || b_components[i] == nullptr) {
1008 return nullptr;
1009 }
1010
1011 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1012 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +01001013 if (component == nullptr) {
1014 return nullptr;
1015 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001016 result_const =
1017 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1018 }
1019 return result_const;
1020 };
1021}
1022
1023// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
1024// from zero.
1025UnaryScalarFoldingRule FoldFNegateOp() {
1026 return [](const analysis::Type* result_type, const analysis::Constant* a,
1027 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1028 assert(result_type != nullptr && a != nullptr);
1029 assert(result_type == a->type());
Nicolas Capens84c9c452022-11-18 14:11:05 +00001030 return NegateFPConst(result_type, a, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -08001031 };
1032}
1033
1034ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
1035
Nicolas Capens84c9c452022-11-18 14:11:05 +00001036ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001037 return [cmp_opcode](IRContext* context, Instruction* inst,
1038 const std::vector<const analysis::Constant*>& constants)
1039 -> const analysis::Constant* {
1040 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1041 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1042
1043 if (!inst->IsFloatingPointFoldingAllowed()) {
1044 return nullptr;
1045 }
1046
1047 uint32_t non_const_idx = (constants[0] ? 1 : 0);
1048 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1049 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1050
1051 analysis::TypeManager* type_mgr = context->get_type_mgr();
1052 const analysis::Type* operand_type =
1053 type_mgr->GetType(operand_inst->type_id());
1054
1055 if (!operand_type->AsFloat()) {
1056 return nullptr;
1057 }
1058
1059 if (operand_type->AsFloat()->width() != 32 &&
1060 operand_type->AsFloat()->width() != 64) {
1061 return nullptr;
1062 }
1063
Nicolas Capens84c9c452022-11-18 14:11:05 +00001064 if (operand_inst->opcode() != spv::Op::OpExtInst) {
Chris Forbescc5697f2019-01-30 11:54:08 -08001065 return nullptr;
1066 }
1067
1068 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1069 return nullptr;
1070 }
1071
1072 if (constants[1] == nullptr && constants[0] == nullptr) {
1073 return nullptr;
1074 }
1075
1076 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1077 const analysis::Constant* max_const =
1078 const_mgr->FindDeclaredConstant(max_id);
1079
1080 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1081 const analysis::Constant* min_const =
1082 const_mgr->FindDeclaredConstant(min_id);
1083
1084 bool found_result = false;
1085 bool result = false;
1086
1087 switch (cmp_opcode) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001088 case spv::Op::OpFOrdLessThan:
1089 case spv::Op::OpFUnordLessThan:
1090 case spv::Op::OpFOrdGreaterThanEqual:
1091 case spv::Op::OpFUnordGreaterThanEqual:
Chris Forbescc5697f2019-01-30 11:54:08 -08001092 if (constants[0]) {
1093 if (min_const) {
1094 if (constants[0]->GetValueAsDouble() <
1095 min_const->GetValueAsDouble()) {
1096 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001097 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1098 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001099 }
1100 }
1101 if (max_const) {
1102 if (constants[0]->GetValueAsDouble() >=
1103 max_const->GetValueAsDouble()) {
1104 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001105 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1106 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001107 }
1108 }
1109 }
1110
1111 if (constants[1]) {
1112 if (max_const) {
1113 if (max_const->GetValueAsDouble() <
1114 constants[1]->GetValueAsDouble()) {
1115 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001116 result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1117 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001118 }
1119 }
1120
1121 if (min_const) {
1122 if (min_const->GetValueAsDouble() >=
1123 constants[1]->GetValueAsDouble()) {
1124 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001125 result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1126 cmp_opcode == spv::Op::OpFUnordLessThan);
Chris Forbescc5697f2019-01-30 11:54:08 -08001127 }
1128 }
1129 }
1130 break;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001131 case spv::Op::OpFOrdGreaterThan:
1132 case spv::Op::OpFUnordGreaterThan:
1133 case spv::Op::OpFOrdLessThanEqual:
1134 case spv::Op::OpFUnordLessThanEqual:
Chris Forbescc5697f2019-01-30 11:54:08 -08001135 if (constants[0]) {
1136 if (min_const) {
1137 if (constants[0]->GetValueAsDouble() <=
1138 min_const->GetValueAsDouble()) {
1139 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001140 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1141 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001142 }
1143 }
1144 if (max_const) {
1145 if (constants[0]->GetValueAsDouble() >
1146 max_const->GetValueAsDouble()) {
1147 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001148 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1149 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001150 }
1151 }
1152 }
1153
1154 if (constants[1]) {
1155 if (max_const) {
1156 if (max_const->GetValueAsDouble() <=
1157 constants[1]->GetValueAsDouble()) {
1158 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001159 result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1160 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001161 }
1162 }
1163
1164 if (min_const) {
1165 if (min_const->GetValueAsDouble() >
1166 constants[1]->GetValueAsDouble()) {
1167 found_result = true;
Nicolas Capens84c9c452022-11-18 14:11:05 +00001168 result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1169 cmp_opcode == spv::Op::OpFUnordLessThanEqual);
Chris Forbescc5697f2019-01-30 11:54:08 -08001170 }
1171 }
1172 }
1173 break;
1174 default:
1175 return nullptr;
1176 }
1177
1178 if (!found_result) {
1179 return nullptr;
1180 }
1181
1182 const analysis::Type* bool_type =
1183 context->get_type_mgr()->GetType(inst->type_id());
1184 const analysis::Constant* result_const =
1185 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1186 assert(result_const);
1187 return result_const;
1188 };
1189}
1190
Ben Claytond0f684e2019-08-30 22:36:08 +01001191ConstantFoldingRule FoldFMix() {
1192 return [](IRContext* context, Instruction* inst,
1193 const std::vector<const analysis::Constant*>& constants)
1194 -> const analysis::Constant* {
1195 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
Nicolas Capens84c9c452022-11-18 14:11:05 +00001196 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond0f684e2019-08-30 22:36:08 +01001197 "Expecting an extended instruction.");
1198 assert(inst->GetSingleWordInOperand(0) ==
1199 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1200 "Expecting a GLSLstd450 extended instruction.");
1201 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1202 "Expecting and FMix instruction.");
1203
1204 if (!inst->IsFloatingPointFoldingAllowed()) {
1205 return nullptr;
1206 }
1207
1208 // Make sure all FMix operands are constants.
1209 for (uint32_t i = 1; i < 4; i++) {
1210 if (constants[i] == nullptr) {
1211 return nullptr;
1212 }
1213 }
1214
1215 const analysis::Constant* one;
Ben Claytond552f632019-11-18 11:18:41 +00001216 bool is_vector = false;
1217 const analysis::Type* result_type = constants[1]->type();
1218 const analysis::Type* base_type = result_type;
1219 if (base_type->AsVector()) {
1220 is_vector = true;
1221 base_type = base_type->AsVector()->element_type();
1222 }
1223 assert(base_type->AsFloat() != nullptr &&
1224 "FMix is suppose to act on floats or vectors of floats.");
1225
1226 if (base_type->AsFloat()->width() == 32) {
1227 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +01001228 utils::FloatProxy<float>(1.0f).GetWords());
1229 } else {
Ben Claytond552f632019-11-18 11:18:41 +00001230 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +01001231 utils::FloatProxy<double>(1.0).GetWords());
1232 }
1233
Ben Claytond552f632019-11-18 11:18:41 +00001234 if (is_vector) {
1235 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
1236 one =
1237 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1238 }
1239
1240 const analysis::Constant* temp1 = FoldFPBinaryOp(
1241 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001242 if (temp1 == nullptr) {
1243 return nullptr;
1244 }
1245
Ben Claytond552f632019-11-18 11:18:41 +00001246 const analysis::Constant* temp2 = FoldFPBinaryOp(
1247 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001248 if (temp2 == nullptr) {
1249 return nullptr;
1250 }
Ben Claytond552f632019-11-18 11:18:41 +00001251 const analysis::Constant* temp3 =
1252 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1253 {constants[2], constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001254 if (temp3 == nullptr) {
1255 return nullptr;
1256 }
Ben Claytond552f632019-11-18 11:18:41 +00001257 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1258 context);
Ben Claytond0f684e2019-08-30 22:36:08 +01001259 };
1260}
1261
Ben Claytond552f632019-11-18 11:18:41 +00001262const analysis::Constant* FoldMin(const analysis::Type* result_type,
1263 const analysis::Constant* a,
1264 const analysis::Constant* b,
1265 analysis::ConstantManager*) {
1266 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1267 if (int_type->width() == 32) {
1268 if (int_type->IsSigned()) {
1269 int32_t va = a->GetS32();
1270 int32_t vb = b->GetS32();
1271 return (va < vb ? a : b);
1272 } else {
1273 uint32_t va = a->GetU32();
1274 uint32_t vb = b->GetU32();
1275 return (va < vb ? a : b);
1276 }
1277 } else if (int_type->width() == 64) {
1278 if (int_type->IsSigned()) {
1279 int64_t va = a->GetS64();
1280 int64_t vb = b->GetS64();
1281 return (va < vb ? a : b);
1282 } else {
1283 uint64_t va = a->GetU64();
1284 uint64_t vb = b->GetU64();
1285 return (va < vb ? a : b);
1286 }
1287 }
1288 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1289 if (float_type->width() == 32) {
1290 float va = a->GetFloat();
1291 float vb = b->GetFloat();
1292 return (va < vb ? a : b);
1293 } else if (float_type->width() == 64) {
1294 double va = a->GetDouble();
1295 double vb = b->GetDouble();
1296 return (va < vb ? a : b);
1297 }
1298 }
1299 return nullptr;
1300}
1301
1302const analysis::Constant* FoldMax(const analysis::Type* result_type,
1303 const analysis::Constant* a,
1304 const analysis::Constant* b,
1305 analysis::ConstantManager*) {
1306 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1307 if (int_type->width() == 32) {
1308 if (int_type->IsSigned()) {
1309 int32_t va = a->GetS32();
1310 int32_t vb = b->GetS32();
1311 return (va > vb ? a : b);
1312 } else {
1313 uint32_t va = a->GetU32();
1314 uint32_t vb = b->GetU32();
1315 return (va > vb ? a : b);
1316 }
1317 } else if (int_type->width() == 64) {
1318 if (int_type->IsSigned()) {
1319 int64_t va = a->GetS64();
1320 int64_t vb = b->GetS64();
1321 return (va > vb ? a : b);
1322 } else {
1323 uint64_t va = a->GetU64();
1324 uint64_t vb = b->GetU64();
1325 return (va > vb ? a : b);
1326 }
1327 }
1328 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1329 if (float_type->width() == 32) {
1330 float va = a->GetFloat();
1331 float vb = b->GetFloat();
1332 return (va > vb ? a : b);
1333 } else if (float_type->width() == 64) {
1334 double va = a->GetDouble();
1335 double vb = b->GetDouble();
1336 return (va > vb ? a : b);
1337 }
1338 }
1339 return nullptr;
1340}
1341
1342// Fold an clamp instruction when all three operands are constant.
1343const analysis::Constant* FoldClamp1(
1344 IRContext* context, Instruction* inst,
1345 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001346 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001347 "Expecting an extended instruction.");
1348 assert(inst->GetSingleWordInOperand(0) ==
1349 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1350 "Expecting a GLSLstd450 extended instruction.");
1351
1352 // Make sure all Clamp operands are constants.
Alexis Hetu00e0af12021-11-08 08:57:46 -05001353 for (uint32_t i = 1; i < 4; i++) {
Ben Claytond552f632019-11-18 11:18:41 +00001354 if (constants[i] == nullptr) {
1355 return nullptr;
1356 }
1357 }
1358
1359 const analysis::Constant* temp = FoldFPBinaryOp(
1360 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1361 if (temp == nullptr) {
1362 return nullptr;
1363 }
1364 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1365 context);
1366}
1367
Alexis Hetu00e0af12021-11-08 08:57:46 -05001368// Fold a clamp instruction when |x <= min_val|.
Ben Claytond552f632019-11-18 11:18:41 +00001369const analysis::Constant* FoldClamp2(
1370 IRContext* context, Instruction* inst,
1371 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001372 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001373 "Expecting an extended instruction.");
1374 assert(inst->GetSingleWordInOperand(0) ==
1375 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1376 "Expecting a GLSLstd450 extended instruction.");
1377
1378 const analysis::Constant* x = constants[1];
1379 const analysis::Constant* min_val = constants[2];
1380
1381 if (x == nullptr || min_val == nullptr) {
1382 return nullptr;
1383 }
1384
1385 const analysis::Constant* temp =
1386 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1387 if (temp == min_val) {
1388 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1389 // result of the max operation is |min_val|, we know the result of the min
1390 // operation, even if |max_val| is not a constant.
1391 return min_val;
1392 }
1393 return nullptr;
1394}
1395
1396// Fold a clamp instruction when |x >= max_val|.
1397const analysis::Constant* FoldClamp3(
1398 IRContext* context, Instruction* inst,
1399 const std::vector<const analysis::Constant*>& constants) {
Nicolas Capens84c9c452022-11-18 14:11:05 +00001400 assert(inst->opcode() == spv::Op::OpExtInst &&
Ben Claytond552f632019-11-18 11:18:41 +00001401 "Expecting an extended instruction.");
1402 assert(inst->GetSingleWordInOperand(0) ==
1403 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1404 "Expecting a GLSLstd450 extended instruction.");
1405
1406 const analysis::Constant* x = constants[1];
1407 const analysis::Constant* max_val = constants[3];
1408
1409 if (x == nullptr || max_val == nullptr) {
1410 return nullptr;
1411 }
1412
1413 const analysis::Constant* temp =
1414 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1415 if (temp == max_val) {
1416 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1417 // result of the max operation is |min_val|, we know the result of the min
1418 // operation, even if |max_val| is not a constant.
1419 return max_val;
1420 }
1421 return nullptr;
1422}
1423
Ben Claytondc6b76a2020-02-24 14:53:40 +00001424UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1425 return
1426 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1427 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1428 assert(result_type != nullptr && a != nullptr);
1429 const analysis::Float* float_type = a->type()->AsFloat();
1430 assert(float_type != nullptr);
1431 assert(float_type == result_type->AsFloat());
1432 if (float_type->width() == 32) {
1433 float fa = a->GetFloat();
1434 float res = static_cast<float>(fp(fa));
1435 utils::FloatProxy<float> result(res);
1436 std::vector<uint32_t> words = result.GetWords();
1437 return const_mgr->GetConstant(result_type, words);
1438 } else if (float_type->width() == 64) {
1439 double fa = a->GetDouble();
1440 double res = fp(fa);
1441 utils::FloatProxy<double> result(res);
1442 std::vector<uint32_t> words = result.GetWords();
1443 return const_mgr->GetConstant(result_type, words);
1444 }
1445 return nullptr;
1446 };
1447}
1448
1449BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1450 double)) {
1451 return
1452 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1453 const analysis::Constant* b,
1454 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1455 assert(result_type != nullptr && a != nullptr);
1456 const analysis::Float* float_type = a->type()->AsFloat();
1457 assert(float_type != nullptr);
1458 assert(float_type == result_type->AsFloat());
1459 assert(float_type == b->type()->AsFloat());
1460 if (float_type->width() == 32) {
1461 float fa = a->GetFloat();
1462 float fb = b->GetFloat();
1463 float res = static_cast<float>(fp(fa, fb));
1464 utils::FloatProxy<float> result(res);
1465 std::vector<uint32_t> words = result.GetWords();
1466 return const_mgr->GetConstant(result_type, words);
1467 } else if (float_type->width() == 64) {
1468 double fa = a->GetDouble();
1469 double fb = b->GetDouble();
1470 double res = fp(fa, fb);
1471 utils::FloatProxy<double> result(res);
1472 std::vector<uint32_t> words = result.GetWords();
1473 return const_mgr->GetConstant(result_type, words);
1474 }
1475 return nullptr;
1476 };
1477}
Chris Forbescc5697f2019-01-30 11:54:08 -08001478} // namespace
1479
Ben Claytond0f684e2019-08-30 22:36:08 +01001480void ConstantFoldingRules::AddFoldingRules() {
Chris Forbescc5697f2019-01-30 11:54:08 -08001481 // Add all folding rules to the list for the opcodes to which they apply.
1482 // Note that the order in which rules are added to the list matters. If a rule
1483 // applies to the instruction, the rest of the rules will not be attempted.
1484 // Take that into consideration.
1485
Nicolas Capens84c9c452022-11-18 14:11:05 +00001486 rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
Chris Forbescc5697f2019-01-30 11:54:08 -08001487
Nicolas Capens84c9c452022-11-18 14:11:05 +00001488 rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1489 rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
Chris Forbescc5697f2019-01-30 11:54:08 -08001490
Nicolas Capens84c9c452022-11-18 14:11:05 +00001491 rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1492 rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1493 rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1494 rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
Chris Forbescc5697f2019-01-30 11:54:08 -08001495
Nicolas Capens84c9c452022-11-18 14:11:05 +00001496 rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1497 rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1498 rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1499 rules_[spv::Op::OpFMul].push_back(FoldFMul());
1500 rules_[spv::Op::OpFSub].push_back(FoldFSub());
Chris Forbescc5697f2019-01-30 11:54:08 -08001501
Nicolas Capens84c9c452022-11-18 14:11:05 +00001502 rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001503
Nicolas Capens84c9c452022-11-18 14:11:05 +00001504 rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001505
Nicolas Capens84c9c452022-11-18 14:11:05 +00001506 rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001507
Nicolas Capens84c9c452022-11-18 14:11:05 +00001508 rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
Chris Forbescc5697f2019-01-30 11:54:08 -08001509
Nicolas Capens84c9c452022-11-18 14:11:05 +00001510 rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1511 rules_[spv::Op::OpFOrdLessThan].push_back(
1512 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001513
Nicolas Capens84c9c452022-11-18 14:11:05 +00001514 rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1515 rules_[spv::Op::OpFUnordLessThan].push_back(
1516 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001517
Nicolas Capens84c9c452022-11-18 14:11:05 +00001518 rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1519 rules_[spv::Op::OpFOrdGreaterThan].push_back(
1520 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001521
Nicolas Capens84c9c452022-11-18 14:11:05 +00001522 rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1523 rules_[spv::Op::OpFUnordGreaterThan].push_back(
1524 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
Chris Forbescc5697f2019-01-30 11:54:08 -08001525
Nicolas Capens84c9c452022-11-18 14:11:05 +00001526 rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1527 rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1528 FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001529
Nicolas Capens84c9c452022-11-18 14:11:05 +00001530 rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1531 rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1532 FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001533
Nicolas Capens84c9c452022-11-18 14:11:05 +00001534 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1535 rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1536 FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001537
Nicolas Capens84c9c452022-11-18 14:11:05 +00001538 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1539 FoldFUnordGreaterThanEqual());
1540 rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1541 FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
Chris Forbescc5697f2019-01-30 11:54:08 -08001542
Nicolas Capens84c9c452022-11-18 14:11:05 +00001543 rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1544 rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1545 rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1546 rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
Chris Forbescc5697f2019-01-30 11:54:08 -08001547
Nicolas Capens84c9c452022-11-18 14:11:05 +00001548 rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1549 rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
Ben Claytond0f684e2019-08-30 22:36:08 +01001550
1551 // Add rules for GLSLstd450
1552 FeatureManager* feature_manager = context_->get_feature_mgr();
1553 uint32_t ext_inst_glslstd450_id =
1554 feature_manager->GetExtInstImportId_GLSLstd450();
1555 if (ext_inst_glslstd450_id != 0) {
1556 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
Ben Claytond552f632019-11-18 11:18:41 +00001557 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1558 FoldFPBinaryOp(FoldMin));
1559 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1560 FoldFPBinaryOp(FoldMin));
1561 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1562 FoldFPBinaryOp(FoldMin));
1563 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1564 FoldFPBinaryOp(FoldMax));
1565 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1566 FoldFPBinaryOp(FoldMax));
1567 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1568 FoldFPBinaryOp(FoldMax));
1569 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1570 FoldClamp1);
1571 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1572 FoldClamp2);
1573 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1574 FoldClamp3);
1575 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1576 FoldClamp1);
1577 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1578 FoldClamp2);
1579 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1580 FoldClamp3);
1581 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1582 FoldClamp1);
1583 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1584 FoldClamp2);
1585 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1586 FoldClamp3);
Ben Claytondc6b76a2020-02-24 14:53:40 +00001587 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1588 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1589 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1590 FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1591 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1592 FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1593 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1594 FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1595 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1596 FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1597 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1598 FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1599 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1600 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1601 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1602 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1603
1604#ifdef __ANDROID__
sugoi1b398bf32022-02-18 10:27:28 -05001605 // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
Ben Claytondc6b76a2020-02-24 14:53:40 +00001606 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1607 // available up until ABI 18 so we use a shim
1608 auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1609 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1610 FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1611 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1612 FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1613#else
1614 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1615 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1616 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1617 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1618#endif
1619
1620 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1621 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1622 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1623 FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1624 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1625 FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
Ben Claytond0f684e2019-08-30 22:36:08 +01001626 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001627}
1628} // namespace opt
1629} // namespace spvtools