blob: 2a2493fd6fc0786e17d5575442e9b41357dd19c1 [file] [log] [blame]
Chris Forbescc5697f2019-01-30 11:54:08 -08001// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
22
// Position, within the |constants| vector handed to an OpCompositeExtract
// folding rule, of the composite operand being extracted from.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24
25// Returns true if |type| is Float or a vector of Float.
26bool HasFloatingPoint(const analysis::Type* type) {
27 if (type->AsFloat()) {
28 return true;
29 } else if (const analysis::Vector* vec_type = type->AsVector()) {
30 return vec_type->element_type()->AsFloat() != nullptr;
31 }
32
33 return false;
34}
35
36// Folds an OpcompositeExtract where input is a composite constant.
37ConstantFoldingRule FoldExtractWithConstants() {
38 return [](IRContext* context, Instruction* inst,
39 const std::vector<const analysis::Constant*>& constants)
40 -> const analysis::Constant* {
41 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42 if (c == nullptr) {
43 return nullptr;
44 }
45
46 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47 uint32_t element_index = inst->GetSingleWordInOperand(i);
48 if (c->AsNullConstant()) {
49 // Return Null for the return type.
50 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51 analysis::TypeManager* type_mgr = context->get_type_mgr();
52 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53 }
54
55 auto cc = c->AsCompositeConstant();
56 assert(cc != nullptr);
57 auto components = cc->GetComponents();
Ben Claytond0f684e2019-08-30 22:36:08 +010058 // Protect against invalid IR. Refuse to fold if the index is out
59 // of bounds.
60 if (element_index >= components.size()) return nullptr;
Chris Forbescc5697f2019-01-30 11:54:08 -080061 c = components[element_index];
62 }
63 return c;
64 };
65}
66
67ConstantFoldingRule FoldVectorShuffleWithConstants() {
68 return [](IRContext* context, Instruction* inst,
69 const std::vector<const analysis::Constant*>& constants)
70 -> const analysis::Constant* {
71 assert(inst->opcode() == SpvOpVectorShuffle);
72 const analysis::Constant* c1 = constants[0];
73 const analysis::Constant* c2 = constants[1];
74 if (c1 == nullptr || c2 == nullptr) {
75 return nullptr;
76 }
77
78 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
79 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
80
81 std::vector<const analysis::Constant*> c1_components;
82 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
83 c1_components = vec_const->GetComponents();
84 } else {
85 assert(c1->AsNullConstant());
86 const analysis::Constant* element =
87 const_mgr->GetConstant(element_type, {});
88 c1_components.resize(c1->type()->AsVector()->element_count(), element);
89 }
90 std::vector<const analysis::Constant*> c2_components;
91 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
92 c2_components = vec_const->GetComponents();
93 } else {
94 assert(c2->AsNullConstant());
95 const analysis::Constant* element =
96 const_mgr->GetConstant(element_type, {});
97 c2_components.resize(c2->type()->AsVector()->element_count(), element);
98 }
99
100 std::vector<uint32_t> ids;
101 const uint32_t undef_literal_value = 0xffffffff;
102 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
103 uint32_t index = inst->GetSingleWordInOperand(i);
104 if (index == undef_literal_value) {
105 // Don't fold shuffle with undef literal value.
106 return nullptr;
107 } else if (index < c1_components.size()) {
108 Instruction* member_inst =
109 const_mgr->GetDefiningInstruction(c1_components[index]);
110 ids.push_back(member_inst->result_id());
111 } else {
112 Instruction* member_inst = const_mgr->GetDefiningInstruction(
113 c2_components[index - c1_components.size()]);
114 ids.push_back(member_inst->result_id());
115 }
116 }
117
118 analysis::TypeManager* type_mgr = context->get_type_mgr();
119 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
120 };
121}
122
123ConstantFoldingRule FoldVectorTimesScalar() {
124 return [](IRContext* context, Instruction* inst,
125 const std::vector<const analysis::Constant*>& constants)
126 -> const analysis::Constant* {
127 assert(inst->opcode() == SpvOpVectorTimesScalar);
128 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
129 analysis::TypeManager* type_mgr = context->get_type_mgr();
130
131 if (!inst->IsFloatingPointFoldingAllowed()) {
132 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
133 return nullptr;
134 }
135 }
136
137 const analysis::Constant* c1 = constants[0];
138 const analysis::Constant* c2 = constants[1];
139
140 if (c1 && c1->IsZero()) {
141 return c1;
142 }
143
144 if (c2 && c2->IsZero()) {
145 // Get or create the NullConstant for this type.
146 std::vector<uint32_t> ids;
147 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
148 }
149
150 if (c1 == nullptr || c2 == nullptr) {
151 return nullptr;
152 }
153
154 // Check result type.
155 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
156 const analysis::Vector* vector_type = result_type->AsVector();
157 assert(vector_type != nullptr);
158 const analysis::Type* element_type = vector_type->element_type();
159 assert(element_type != nullptr);
160 const analysis::Float* float_type = element_type->AsFloat();
161 assert(float_type != nullptr);
162
163 // Check types of c1 and c2.
164 assert(c1->type()->AsVector() == vector_type);
165 assert(c1->type()->AsVector()->element_type() == element_type &&
166 c2->type() == element_type);
167
168 // Get a float vector that is the result of vector-times-scalar.
169 std::vector<const analysis::Constant*> c1_components =
170 c1->GetVectorComponents(const_mgr);
171 std::vector<uint32_t> ids;
172 if (float_type->width() == 32) {
173 float scalar = c2->GetFloat();
174 for (uint32_t i = 0; i < c1_components.size(); ++i) {
175 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
176 std::vector<uint32_t> words = result.GetWords();
177 const analysis::Constant* new_elem =
178 const_mgr->GetConstant(float_type, words);
179 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
180 }
181 return const_mgr->GetConstant(vector_type, ids);
182 } else if (float_type->width() == 64) {
183 double scalar = c2->GetDouble();
184 for (uint32_t i = 0; i < c1_components.size(); ++i) {
185 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
186 scalar);
187 std::vector<uint32_t> words = result.GetWords();
188 const analysis::Constant* new_elem =
189 const_mgr->GetConstant(float_type, words);
190 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
191 }
192 return const_mgr->GetConstant(vector_type, ids);
193 }
194 return nullptr;
195 };
196}
197
198ConstantFoldingRule FoldCompositeWithConstants() {
199 // Folds an OpCompositeConstruct where all of the inputs are constants to a
200 // constant. A new constant is created if necessary.
201 return [](IRContext* context, Instruction* inst,
202 const std::vector<const analysis::Constant*>& constants)
203 -> const analysis::Constant* {
204 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
205 analysis::TypeManager* type_mgr = context->get_type_mgr();
206 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
207 Instruction* type_inst =
208 context->get_def_use_mgr()->GetDef(inst->type_id());
209
210 std::vector<uint32_t> ids;
211 for (uint32_t i = 0; i < constants.size(); ++i) {
212 const analysis::Constant* element_const = constants[i];
213 if (element_const == nullptr) {
214 return nullptr;
215 }
216
217 uint32_t component_type_id = 0;
218 if (type_inst->opcode() == SpvOpTypeStruct) {
219 component_type_id = type_inst->GetSingleWordInOperand(i);
220 } else if (type_inst->opcode() == SpvOpTypeArray) {
221 component_type_id = type_inst->GetSingleWordInOperand(0);
222 }
223
224 uint32_t element_id =
225 const_mgr->FindDeclaredConstant(element_const, component_type_id);
226 if (element_id == 0) {
227 return nullptr;
228 }
229 ids.push_back(element_id);
230 }
231 return const_mgr->GetConstant(new_type, ids);
232 };
233}
234
235// The interface for a function that returns the result of applying a scalar
236// floating-point binary operation on |a| and |b|. The type of the return value
237// will be |type|. The input constants must also be of type |type|.
238using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
239 const analysis::Type* result_type, const analysis::Constant* a,
240 analysis::ConstantManager*)>;
241
242// The interface for a function that returns the result of applying a scalar
243// floating-point binary operation on |a| and |b|. The type of the return value
244// will be |type|. The input constants must also be of type |type|.
245using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
246 const analysis::Type* result_type, const analysis::Constant* a,
247 const analysis::Constant* b, analysis::ConstantManager*)>;
248
249// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
250// using |scalar_rule| and unary float point vectors ops by applying
251// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
252// that is returned assumes that |constants| contains 1 entry. If they are
253// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
254// whose element type is |Float| or |Integer|.
255ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
256 return [scalar_rule](IRContext* context, Instruction* inst,
257 const std::vector<const analysis::Constant*>& constants)
258 -> const analysis::Constant* {
259 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
260 analysis::TypeManager* type_mgr = context->get_type_mgr();
261 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
262 const analysis::Vector* vector_type = result_type->AsVector();
263
264 if (!inst->IsFloatingPointFoldingAllowed()) {
265 return nullptr;
266 }
267
268 if (constants[0] == nullptr) {
269 return nullptr;
270 }
271
272 if (vector_type != nullptr) {
273 std::vector<const analysis::Constant*> a_components;
274 std::vector<const analysis::Constant*> results_components;
275
276 a_components = constants[0]->GetVectorComponents(const_mgr);
277
278 // Fold each component of the vector.
279 for (uint32_t i = 0; i < a_components.size(); ++i) {
280 results_components.push_back(scalar_rule(vector_type->element_type(),
281 a_components[i], const_mgr));
282 if (results_components[i] == nullptr) {
283 return nullptr;
284 }
285 }
286
287 // Build the constant object and return it.
288 std::vector<uint32_t> ids;
289 for (const analysis::Constant* member : results_components) {
290 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
291 }
292 return const_mgr->GetConstant(vector_type, ids);
293 } else {
294 return scalar_rule(result_type, constants[0], const_mgr);
295 }
296 };
297}
298
Ben Claytond552f632019-11-18 11:18:41 +0000299// Returns the result of folding the constants in |constants| according the
300// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
301// per component.
302const analysis::Constant* FoldFPBinaryOp(
303 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
304 const std::vector<const analysis::Constant*>& constants,
305 IRContext* context) {
306 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
307 analysis::TypeManager* type_mgr = context->get_type_mgr();
308 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
309 const analysis::Vector* vector_type = result_type->AsVector();
310
311 if (constants[0] == nullptr || constants[1] == nullptr) {
312 return nullptr;
313 }
314
315 if (vector_type != nullptr) {
316 std::vector<const analysis::Constant*> a_components;
317 std::vector<const analysis::Constant*> b_components;
318 std::vector<const analysis::Constant*> results_components;
319
320 a_components = constants[0]->GetVectorComponents(const_mgr);
321 b_components = constants[1]->GetVectorComponents(const_mgr);
322
323 // Fold each component of the vector.
324 for (uint32_t i = 0; i < a_components.size(); ++i) {
325 results_components.push_back(scalar_rule(vector_type->element_type(),
326 a_components[i], b_components[i],
327 const_mgr));
328 if (results_components[i] == nullptr) {
329 return nullptr;
330 }
331 }
332
333 // Build the constant object and return it.
334 std::vector<uint32_t> ids;
335 for (const analysis::Constant* member : results_components) {
336 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
337 }
338 return const_mgr->GetConstant(vector_type, ids);
339 } else {
340 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
341 }
342}
343
Chris Forbescc5697f2019-01-30 11:54:08 -0800344// Returns a |ConstantFoldingRule| that folds floating point scalars using
345// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
346// elements of the vector. The |ConstantFoldingRule| that is returned assumes
347// that |constants| contains 2 entries. If they are not |nullptr|, then their
348// type is either |Float| or a |Vector| whose element type is |Float|.
349ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
350 return [scalar_rule](IRContext* context, Instruction* inst,
351 const std::vector<const analysis::Constant*>& constants)
352 -> const analysis::Constant* {
Chris Forbescc5697f2019-01-30 11:54:08 -0800353 if (!inst->IsFloatingPointFoldingAllowed()) {
354 return nullptr;
355 }
Ben Claytond552f632019-11-18 11:18:41 +0000356 if (inst->opcode() == SpvOpExtInst) {
357 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
358 {constants[1], constants[2]}, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800359 }
Ben Claytond552f632019-11-18 11:18:41 +0000360 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800361 };
362}
363
364// This macro defines a |UnaryScalarFoldingRule| that performs float to
365// integer conversion.
366// TODO(greg-lunarg): Support for 64-bit integer types.
367UnaryScalarFoldingRule FoldFToIOp() {
368 return [](const analysis::Type* result_type, const analysis::Constant* a,
369 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
370 assert(result_type != nullptr && a != nullptr);
371 const analysis::Integer* integer_type = result_type->AsInteger();
372 const analysis::Float* float_type = a->type()->AsFloat();
373 assert(float_type != nullptr);
374 assert(integer_type != nullptr);
375 if (integer_type->width() != 32) return nullptr;
376 if (float_type->width() == 32) {
377 float fa = a->GetFloat();
378 uint32_t result = integer_type->IsSigned()
379 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
380 : static_cast<uint32_t>(fa);
381 std::vector<uint32_t> words = {result};
382 return const_mgr->GetConstant(result_type, words);
383 } else if (float_type->width() == 64) {
384 double fa = a->GetDouble();
385 uint32_t result = integer_type->IsSigned()
386 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
387 : static_cast<uint32_t>(fa);
388 std::vector<uint32_t> words = {result};
389 return const_mgr->GetConstant(result_type, words);
390 }
391 return nullptr;
392 };
393}
394
395// This function defines a |UnaryScalarFoldingRule| that performs integer to
396// float conversion.
397// TODO(greg-lunarg): Support for 64-bit integer types.
398UnaryScalarFoldingRule FoldIToFOp() {
399 return [](const analysis::Type* result_type, const analysis::Constant* a,
400 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
401 assert(result_type != nullptr && a != nullptr);
402 const analysis::Integer* integer_type = a->type()->AsInteger();
403 const analysis::Float* float_type = result_type->AsFloat();
404 assert(float_type != nullptr);
405 assert(integer_type != nullptr);
406 if (integer_type->width() != 32) return nullptr;
407 uint32_t ua = a->GetU32();
408 if (float_type->width() == 32) {
409 float result_val = integer_type->IsSigned()
410 ? static_cast<float>(static_cast<int32_t>(ua))
411 : static_cast<float>(ua);
412 utils::FloatProxy<float> result(result_val);
413 std::vector<uint32_t> words = {result.data()};
414 return const_mgr->GetConstant(result_type, words);
415 } else if (float_type->width() == 64) {
416 double result_val = integer_type->IsSigned()
417 ? static_cast<double>(static_cast<int32_t>(ua))
418 : static_cast<double>(ua);
419 utils::FloatProxy<double> result(result_val);
420 std::vector<uint32_t> words = result.GetWords();
421 return const_mgr->GetConstant(result_type, words);
422 }
423 return nullptr;
424 };
425}
426
Ben Claytonb73b7602019-07-29 13:56:13 +0100427// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
428UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
429 return [](const analysis::Type* result_type, const analysis::Constant* a,
430 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
431 assert(result_type != nullptr && a != nullptr);
432 const analysis::Float* float_type = a->type()->AsFloat();
433 assert(float_type != nullptr);
434 if (float_type->width() != 32) {
435 return nullptr;
436 }
437
438 float fa = a->GetFloat();
439 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
440 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
441 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
442 orignal.castTo(quantized, utils::round_direction::kToZero);
443 quantized.castTo(result, utils::round_direction::kToZero);
444 std::vector<uint32_t> words = {result.getBits()};
445 return const_mgr->GetConstant(result_type, words);
446 };
447}
448
// Expands to a lambda usable as a |BinaryScalarFoldingRule| that computes
// "a op b" for two float constants of the result type. |op| must be an infix
// operator that works on both float and double. Only 32- and 64-bit floats are
// folded; for any other width the lambda returns nullptr. Locals carry an
// "_in_macro" suffix to keep them apart from names at the expansion site.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro,                              \
     const analysis::Constant* a, const analysis::Constant* b,                \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    std::vector<uint32_t> words_in_macro;                                     \
    if (float_type_in_macro->width() == 32) {                                 \
      float fa = a->GetFloat();                                               \
      float fb = b->GetFloat();                                               \
      utils::FloatProxy<float> result_in_macro(fa op fb);                     \
      words_in_macro = result_in_macro.GetWords();                            \
    } else if (float_type_in_macro->width() == 64) {                          \
      double fa = a->GetDouble();                                             \
      double fb = b->GetDouble();                                             \
      utils::FloatProxy<double> result_in_macro(fa op fb);                    \
      words_in_macro = result_in_macro.GetWords();                            \
    } else {                                                                  \
      return nullptr;                                                         \
    }                                                                         \
    return const_mgr_in_macro->GetConstant(result_type_in_macro,              \
                                           words_in_macro);                   \
  }
479
480// Define the folding rule for conversion between floating point and integer
481ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
482ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100483ConstantFoldingRule FoldQuantizeToF16() {
484 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
485}
Chris Forbescc5697f2019-01-30 11:54:08 -0800486
487// Define the folding rules for subtraction, addition, multiplication, and
488// division for floating point values.
489ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
490ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
491ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
492ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
493
// Combines the raw outcome of a comparison with the orderedness of its
// operands. |op_result| is the result of "Operand 1 |op| Operand 2" and
// |op_unordered| tells whether the operands are unordered.
//  - Ordered comparison (|need_ordered| true): true only if the operands are
//    ordered and the comparison holds.
//  - Unordered comparison (|need_ordered| false): true if the operands are
//    unordered or the comparison holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
504
// Expands to a lambda usable as a |BinaryScalarFoldingRule| that compares two
// float constants with |op| and produces a Bool constant. |op| must be an
// infix comparison that works on both float and double. |ord| selects the
// ordered (true) or unordered (false) flavour of the comparison; operands are
// considered unordered when either of them is NaN. Only 32- and 64-bit floats
// are folded; for any other width the lambda returns nullptr.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* float_type = a->type()->AsFloat();             \
    assert(float_type != nullptr);                                        \
    bool result = false;                                                  \
    if (float_type->width() == 32) {                                      \
      const float fa = a->GetFloat();                                     \
      const float fb = b->GetFloat();                                     \
      const bool unordered = std::isnan(fa) || std::isnan(fb);            \
      result = CompareFloatingPoint(fa op fb, unordered, ord);            \
    } else if (float_type->width() == 64) {                               \
      const double fa = a->GetDouble();                                   \
      const double fb = b->GetDouble();                                   \
      const bool unordered = std::isnan(fa) || std::isnan(fb);            \
      result = CompareFloatingPoint(fa op fb, unordered, ord);            \
    } else {                                                              \
      return nullptr;                                                     \
    }                                                                     \
    std::vector<uint32_t> words = {uint32_t(result)};                     \
    return const_mgr->GetConstant(result_type, words);                    \
  }
533
534// Define the folding rules for ordered and unordered comparison for floating
535// point values.
536ConstantFoldingRule FoldFOrdEqual() {
537 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
538}
539ConstantFoldingRule FoldFUnordEqual() {
540 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
541}
542ConstantFoldingRule FoldFOrdNotEqual() {
543 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
544}
545ConstantFoldingRule FoldFUnordNotEqual() {
546 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
547}
548ConstantFoldingRule FoldFOrdLessThan() {
549 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
550}
551ConstantFoldingRule FoldFUnordLessThan() {
552 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
553}
554ConstantFoldingRule FoldFOrdGreaterThan() {
555 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
556}
557ConstantFoldingRule FoldFUnordGreaterThan() {
558 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
559}
560ConstantFoldingRule FoldFOrdLessThanEqual() {
561 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
562}
563ConstantFoldingRule FoldFUnordLessThanEqual() {
564 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
565}
566ConstantFoldingRule FoldFOrdGreaterThanEqual() {
567 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
568}
569ConstantFoldingRule FoldFUnordGreaterThanEqual() {
570 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
571}
572
573// Folds an OpDot where all of the inputs are constants to a
574// constant. A new constant is created if necessary.
575ConstantFoldingRule FoldOpDotWithConstants() {
576 return [](IRContext* context, Instruction* inst,
577 const std::vector<const analysis::Constant*>& constants)
578 -> const analysis::Constant* {
579 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
580 analysis::TypeManager* type_mgr = context->get_type_mgr();
581 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
582 assert(new_type->AsFloat() && "OpDot should have a float return type.");
583 const analysis::Float* float_type = new_type->AsFloat();
584
585 if (!inst->IsFloatingPointFoldingAllowed()) {
586 return nullptr;
587 }
588
589 // If one of the operands is 0, then the result is 0.
590 bool has_zero_operand = false;
591
592 for (int i = 0; i < 2; ++i) {
593 if (constants[i]) {
594 if (constants[i]->AsNullConstant() ||
595 constants[i]->AsVectorConstant()->IsZero()) {
596 has_zero_operand = true;
597 break;
598 }
599 }
600 }
601
602 if (has_zero_operand) {
603 if (float_type->width() == 32) {
604 utils::FloatProxy<float> result(0.0f);
605 std::vector<uint32_t> words = result.GetWords();
606 return const_mgr->GetConstant(float_type, words);
607 }
608 if (float_type->width() == 64) {
609 utils::FloatProxy<double> result(0.0);
610 std::vector<uint32_t> words = result.GetWords();
611 return const_mgr->GetConstant(float_type, words);
612 }
613 return nullptr;
614 }
615
616 if (constants[0] == nullptr || constants[1] == nullptr) {
617 return nullptr;
618 }
619
620 std::vector<const analysis::Constant*> a_components;
621 std::vector<const analysis::Constant*> b_components;
622
623 a_components = constants[0]->GetVectorComponents(const_mgr);
624 b_components = constants[1]->GetVectorComponents(const_mgr);
625
626 utils::FloatProxy<double> result(0.0);
627 std::vector<uint32_t> words = result.GetWords();
628 const analysis::Constant* result_const =
629 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +0100630 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
631 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800632 if (a_components[i] == nullptr || b_components[i] == nullptr) {
633 return nullptr;
634 }
635
636 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
637 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +0100638 if (component == nullptr) {
639 return nullptr;
640 }
Chris Forbescc5697f2019-01-30 11:54:08 -0800641 result_const =
642 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
643 }
644 return result_const;
645 };
646}
647
648// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
649// from zero.
650UnaryScalarFoldingRule FoldFNegateOp() {
651 return [](const analysis::Type* result_type, const analysis::Constant* a,
652 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
653 assert(result_type != nullptr && a != nullptr);
654 assert(result_type == a->type());
655 const analysis::Float* float_type = result_type->AsFloat();
656 assert(float_type != nullptr);
657 if (float_type->width() == 32) {
658 float fa = a->GetFloat();
659 utils::FloatProxy<float> result(-fa);
660 std::vector<uint32_t> words = result.GetWords();
661 return const_mgr->GetConstant(result_type, words);
662 } else if (float_type->width() == 64) {
663 double da = a->GetDouble();
664 utils::FloatProxy<double> result(-da);
665 std::vector<uint32_t> words = result.GetWords();
666 return const_mgr->GetConstant(result_type, words);
667 }
668 return nullptr;
669 };
670}
671
672ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
673
674ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
675 return [cmp_opcode](IRContext* context, Instruction* inst,
676 const std::vector<const analysis::Constant*>& constants)
677 -> const analysis::Constant* {
678 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
679 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
680
681 if (!inst->IsFloatingPointFoldingAllowed()) {
682 return nullptr;
683 }
684
685 uint32_t non_const_idx = (constants[0] ? 1 : 0);
686 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
687 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
688
689 analysis::TypeManager* type_mgr = context->get_type_mgr();
690 const analysis::Type* operand_type =
691 type_mgr->GetType(operand_inst->type_id());
692
693 if (!operand_type->AsFloat()) {
694 return nullptr;
695 }
696
697 if (operand_type->AsFloat()->width() != 32 &&
698 operand_type->AsFloat()->width() != 64) {
699 return nullptr;
700 }
701
702 if (operand_inst->opcode() != SpvOpExtInst) {
703 return nullptr;
704 }
705
706 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
707 return nullptr;
708 }
709
710 if (constants[1] == nullptr && constants[0] == nullptr) {
711 return nullptr;
712 }
713
714 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
715 const analysis::Constant* max_const =
716 const_mgr->FindDeclaredConstant(max_id);
717
718 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
719 const analysis::Constant* min_const =
720 const_mgr->FindDeclaredConstant(min_id);
721
722 bool found_result = false;
723 bool result = false;
724
725 switch (cmp_opcode) {
726 case SpvOpFOrdLessThan:
727 case SpvOpFUnordLessThan:
728 case SpvOpFOrdGreaterThanEqual:
729 case SpvOpFUnordGreaterThanEqual:
730 if (constants[0]) {
731 if (min_const) {
732 if (constants[0]->GetValueAsDouble() <
733 min_const->GetValueAsDouble()) {
734 found_result = true;
735 result = (cmp_opcode == SpvOpFOrdLessThan ||
736 cmp_opcode == SpvOpFUnordLessThan);
737 }
738 }
739 if (max_const) {
740 if (constants[0]->GetValueAsDouble() >=
741 max_const->GetValueAsDouble()) {
742 found_result = true;
743 result = !(cmp_opcode == SpvOpFOrdLessThan ||
744 cmp_opcode == SpvOpFUnordLessThan);
745 }
746 }
747 }
748
749 if (constants[1]) {
750 if (max_const) {
751 if (max_const->GetValueAsDouble() <
752 constants[1]->GetValueAsDouble()) {
753 found_result = true;
754 result = (cmp_opcode == SpvOpFOrdLessThan ||
755 cmp_opcode == SpvOpFUnordLessThan);
756 }
757 }
758
759 if (min_const) {
760 if (min_const->GetValueAsDouble() >=
761 constants[1]->GetValueAsDouble()) {
762 found_result = true;
763 result = !(cmp_opcode == SpvOpFOrdLessThan ||
764 cmp_opcode == SpvOpFUnordLessThan);
765 }
766 }
767 }
768 break;
769 case SpvOpFOrdGreaterThan:
770 case SpvOpFUnordGreaterThan:
771 case SpvOpFOrdLessThanEqual:
772 case SpvOpFUnordLessThanEqual:
773 if (constants[0]) {
774 if (min_const) {
775 if (constants[0]->GetValueAsDouble() <=
776 min_const->GetValueAsDouble()) {
777 found_result = true;
778 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
779 cmp_opcode == SpvOpFUnordLessThanEqual);
780 }
781 }
782 if (max_const) {
783 if (constants[0]->GetValueAsDouble() >
784 max_const->GetValueAsDouble()) {
785 found_result = true;
786 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
787 cmp_opcode == SpvOpFUnordLessThanEqual);
788 }
789 }
790 }
791
792 if (constants[1]) {
793 if (max_const) {
794 if (max_const->GetValueAsDouble() <=
795 constants[1]->GetValueAsDouble()) {
796 found_result = true;
797 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
798 cmp_opcode == SpvOpFUnordLessThanEqual);
799 }
800 }
801
802 if (min_const) {
803 if (min_const->GetValueAsDouble() >
804 constants[1]->GetValueAsDouble()) {
805 found_result = true;
806 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
807 cmp_opcode == SpvOpFUnordLessThanEqual);
808 }
809 }
810 }
811 break;
812 default:
813 return nullptr;
814 }
815
816 if (!found_result) {
817 return nullptr;
818 }
819
820 const analysis::Type* bool_type =
821 context->get_type_mgr()->GetType(inst->type_id());
822 const analysis::Constant* result_const =
823 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
824 assert(result_const);
825 return result_const;
826 };
827}
828
Ben Claytond0f684e2019-08-30 22:36:08 +0100829ConstantFoldingRule FoldFMix() {
830 return [](IRContext* context, Instruction* inst,
831 const std::vector<const analysis::Constant*>& constants)
832 -> const analysis::Constant* {
833 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
834 assert(inst->opcode() == SpvOpExtInst &&
835 "Expecting an extended instruction.");
836 assert(inst->GetSingleWordInOperand(0) ==
837 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
838 "Expecting a GLSLstd450 extended instruction.");
839 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
840 "Expecting and FMix instruction.");
841
842 if (!inst->IsFloatingPointFoldingAllowed()) {
843 return nullptr;
844 }
845
846 // Make sure all FMix operands are constants.
847 for (uint32_t i = 1; i < 4; i++) {
848 if (constants[i] == nullptr) {
849 return nullptr;
850 }
851 }
852
853 const analysis::Constant* one;
Ben Claytond552f632019-11-18 11:18:41 +0000854 bool is_vector = false;
855 const analysis::Type* result_type = constants[1]->type();
856 const analysis::Type* base_type = result_type;
857 if (base_type->AsVector()) {
858 is_vector = true;
859 base_type = base_type->AsVector()->element_type();
860 }
861 assert(base_type->AsFloat() != nullptr &&
862 "FMix is suppose to act on floats or vectors of floats.");
863
864 if (base_type->AsFloat()->width() == 32) {
865 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +0100866 utils::FloatProxy<float>(1.0f).GetWords());
867 } else {
Ben Claytond552f632019-11-18 11:18:41 +0000868 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +0100869 utils::FloatProxy<double>(1.0).GetWords());
870 }
871
Ben Claytond552f632019-11-18 11:18:41 +0000872 if (is_vector) {
873 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
874 one =
875 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
876 }
877
878 const analysis::Constant* temp1 = FoldFPBinaryOp(
879 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100880 if (temp1 == nullptr) {
881 return nullptr;
882 }
883
Ben Claytond552f632019-11-18 11:18:41 +0000884 const analysis::Constant* temp2 = FoldFPBinaryOp(
885 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100886 if (temp2 == nullptr) {
887 return nullptr;
888 }
Ben Claytond552f632019-11-18 11:18:41 +0000889 const analysis::Constant* temp3 =
890 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
891 {constants[2], constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100892 if (temp3 == nullptr) {
893 return nullptr;
894 }
Ben Claytond552f632019-11-18 11:18:41 +0000895 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
896 context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100897 };
898}
899
// Clamps |x| to the range bounded by |min_val| and |max_val|. The lower bound
// is applied first and the upper bound second, so |max_val| takes precedence
// if the bounds are given in the wrong order.
template <class IntType>
IntType FoldIClamp(IntType x, IntType min_val, IntType max_val) {
  const IntType raised = (x < min_val) ? min_val : x;
  return (raised > max_val) ? max_val : raised;
}
910
911const analysis::Constant* FoldMin(const analysis::Type* result_type,
912 const analysis::Constant* a,
913 const analysis::Constant* b,
914 analysis::ConstantManager*) {
915 if (const analysis::Integer* int_type = result_type->AsInteger()) {
916 if (int_type->width() == 32) {
917 if (int_type->IsSigned()) {
918 int32_t va = a->GetS32();
919 int32_t vb = b->GetS32();
920 return (va < vb ? a : b);
921 } else {
922 uint32_t va = a->GetU32();
923 uint32_t vb = b->GetU32();
924 return (va < vb ? a : b);
925 }
926 } else if (int_type->width() == 64) {
927 if (int_type->IsSigned()) {
928 int64_t va = a->GetS64();
929 int64_t vb = b->GetS64();
930 return (va < vb ? a : b);
931 } else {
932 uint64_t va = a->GetU64();
933 uint64_t vb = b->GetU64();
934 return (va < vb ? a : b);
935 }
936 }
937 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
938 if (float_type->width() == 32) {
939 float va = a->GetFloat();
940 float vb = b->GetFloat();
941 return (va < vb ? a : b);
942 } else if (float_type->width() == 64) {
943 double va = a->GetDouble();
944 double vb = b->GetDouble();
945 return (va < vb ? a : b);
946 }
947 }
948 return nullptr;
949}
950
951const analysis::Constant* FoldMax(const analysis::Type* result_type,
952 const analysis::Constant* a,
953 const analysis::Constant* b,
954 analysis::ConstantManager*) {
955 if (const analysis::Integer* int_type = result_type->AsInteger()) {
956 if (int_type->width() == 32) {
957 if (int_type->IsSigned()) {
958 int32_t va = a->GetS32();
959 int32_t vb = b->GetS32();
960 return (va > vb ? a : b);
961 } else {
962 uint32_t va = a->GetU32();
963 uint32_t vb = b->GetU32();
964 return (va > vb ? a : b);
965 }
966 } else if (int_type->width() == 64) {
967 if (int_type->IsSigned()) {
968 int64_t va = a->GetS64();
969 int64_t vb = b->GetS64();
970 return (va > vb ? a : b);
971 } else {
972 uint64_t va = a->GetU64();
973 uint64_t vb = b->GetU64();
974 return (va > vb ? a : b);
975 }
976 }
977 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
978 if (float_type->width() == 32) {
979 float va = a->GetFloat();
980 float vb = b->GetFloat();
981 return (va > vb ? a : b);
982 } else if (float_type->width() == 64) {
983 double va = a->GetDouble();
984 double vb = b->GetDouble();
985 return (va > vb ? a : b);
986 }
987 }
988 return nullptr;
989}
990
991// Fold an clamp instruction when all three operands are constant.
992const analysis::Constant* FoldClamp1(
993 IRContext* context, Instruction* inst,
994 const std::vector<const analysis::Constant*>& constants) {
995 assert(inst->opcode() == SpvOpExtInst &&
996 "Expecting an extended instruction.");
997 assert(inst->GetSingleWordInOperand(0) ==
998 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
999 "Expecting a GLSLstd450 extended instruction.");
1000
1001 // Make sure all Clamp operands are constants.
1002 for (uint32_t i = 1; i < 3; i++) {
1003 if (constants[i] == nullptr) {
1004 return nullptr;
1005 }
1006 }
1007
1008 const analysis::Constant* temp = FoldFPBinaryOp(
1009 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1010 if (temp == nullptr) {
1011 return nullptr;
1012 }
1013 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1014 context);
1015}
1016
1017// Fold a clamp instruction when |x >= min_val|.
1018const analysis::Constant* FoldClamp2(
1019 IRContext* context, Instruction* inst,
1020 const std::vector<const analysis::Constant*>& constants) {
1021 assert(inst->opcode() == SpvOpExtInst &&
1022 "Expecting an extended instruction.");
1023 assert(inst->GetSingleWordInOperand(0) ==
1024 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1025 "Expecting a GLSLstd450 extended instruction.");
1026
1027 const analysis::Constant* x = constants[1];
1028 const analysis::Constant* min_val = constants[2];
1029
1030 if (x == nullptr || min_val == nullptr) {
1031 return nullptr;
1032 }
1033
1034 const analysis::Constant* temp =
1035 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1036 if (temp == min_val) {
1037 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1038 // result of the max operation is |min_val|, we know the result of the min
1039 // operation, even if |max_val| is not a constant.
1040 return min_val;
1041 }
1042 return nullptr;
1043}
1044
1045// Fold a clamp instruction when |x >= max_val|.
1046const analysis::Constant* FoldClamp3(
1047 IRContext* context, Instruction* inst,
1048 const std::vector<const analysis::Constant*>& constants) {
1049 assert(inst->opcode() == SpvOpExtInst &&
1050 "Expecting an extended instruction.");
1051 assert(inst->GetSingleWordInOperand(0) ==
1052 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1053 "Expecting a GLSLstd450 extended instruction.");
1054
1055 const analysis::Constant* x = constants[1];
1056 const analysis::Constant* max_val = constants[3];
1057
1058 if (x == nullptr || max_val == nullptr) {
1059 return nullptr;
1060 }
1061
1062 const analysis::Constant* temp =
1063 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1064 if (temp == max_val) {
1065 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1066 // result of the max operation is |min_val|, we know the result of the min
1067 // operation, even if |max_val| is not a constant.
1068 return max_val;
1069 }
1070 return nullptr;
1071}
1072
Chris Forbescc5697f2019-01-30 11:54:08 -08001073} // namespace
1074
Ben Claytond0f684e2019-08-30 22:36:08 +01001075void ConstantFoldingRules::AddFoldingRules() {
Chris Forbescc5697f2019-01-30 11:54:08 -08001076 // Add all folding rules to the list for the opcodes to which they apply.
1077 // Note that the order in which rules are added to the list matters. If a rule
1078 // applies to the instruction, the rest of the rules will not be attempted.
1079 // Take that into consideration.
1080
1081 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
1082
1083 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
1084
1085 rules_[SpvOpConvertFToS].push_back(FoldFToI());
1086 rules_[SpvOpConvertFToU].push_back(FoldFToI());
1087 rules_[SpvOpConvertSToF].push_back(FoldIToF());
1088 rules_[SpvOpConvertUToF].push_back(FoldIToF());
1089
1090 rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
1091 rules_[SpvOpFAdd].push_back(FoldFAdd());
1092 rules_[SpvOpFDiv].push_back(FoldFDiv());
1093 rules_[SpvOpFMul].push_back(FoldFMul());
1094 rules_[SpvOpFSub].push_back(FoldFSub());
1095
1096 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
1097
1098 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
1099
1100 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1101
1102 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1103
1104 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
1105 rules_[SpvOpFOrdLessThan].push_back(
1106 FoldFClampFeedingCompare(SpvOpFOrdLessThan));
1107
1108 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
1109 rules_[SpvOpFUnordLessThan].push_back(
1110 FoldFClampFeedingCompare(SpvOpFUnordLessThan));
1111
1112 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1113 rules_[SpvOpFOrdGreaterThan].push_back(
1114 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
1115
1116 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1117 rules_[SpvOpFUnordGreaterThan].push_back(
1118 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
1119
1120 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1121 rules_[SpvOpFOrdLessThanEqual].push_back(
1122 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
1123
1124 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1125 rules_[SpvOpFUnordLessThanEqual].push_back(
1126 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
1127
1128 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1129 rules_[SpvOpFOrdGreaterThanEqual].push_back(
1130 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
1131
1132 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
1133 rules_[SpvOpFUnordGreaterThanEqual].push_back(
1134 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
1135
1136 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1137 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1138
1139 rules_[SpvOpFNegate].push_back(FoldFNegate());
Ben Claytonb73b7602019-07-29 13:56:13 +01001140 rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
Ben Claytond0f684e2019-08-30 22:36:08 +01001141
1142 // Add rules for GLSLstd450
1143 FeatureManager* feature_manager = context_->get_feature_mgr();
1144 uint32_t ext_inst_glslstd450_id =
1145 feature_manager->GetExtInstImportId_GLSLstd450();
1146 if (ext_inst_glslstd450_id != 0) {
1147 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
Ben Claytond552f632019-11-18 11:18:41 +00001148 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1149 FoldFPBinaryOp(FoldMin));
1150 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1151 FoldFPBinaryOp(FoldMin));
1152 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1153 FoldFPBinaryOp(FoldMin));
1154 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1155 FoldFPBinaryOp(FoldMax));
1156 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1157 FoldFPBinaryOp(FoldMax));
1158 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1159 FoldFPBinaryOp(FoldMax));
1160 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1161 FoldClamp1);
1162 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1163 FoldClamp2);
1164 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1165 FoldClamp3);
1166 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1167 FoldClamp1);
1168 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1169 FoldClamp2);
1170 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1171 FoldClamp3);
1172 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1173 FoldClamp1);
1174 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1175 FoldClamp2);
1176 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1177 FoldClamp3);
Ben Claytond0f684e2019-08-30 22:36:08 +01001178 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001179}
1180} // namespace opt
1181} // namespace spvtools