blob: 06a1a81e6c56efcf2cfcf4a0872667465da930c2 [file] [log] [blame]
Chris Forbescc5697f2019-01-30 11:54:08 -08001// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
22
// Position of the composite operand among the in-operands of
// OpCompositeExtract; the literal element indices follow it.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24
25// Returns true if |type| is Float or a vector of Float.
26bool HasFloatingPoint(const analysis::Type* type) {
27 if (type->AsFloat()) {
28 return true;
29 } else if (const analysis::Vector* vec_type = type->AsVector()) {
30 return vec_type->element_type()->AsFloat() != nullptr;
31 }
32
33 return false;
34}
35
36// Folds an OpcompositeExtract where input is a composite constant.
37ConstantFoldingRule FoldExtractWithConstants() {
38 return [](IRContext* context, Instruction* inst,
39 const std::vector<const analysis::Constant*>& constants)
40 -> const analysis::Constant* {
41 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42 if (c == nullptr) {
43 return nullptr;
44 }
45
46 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47 uint32_t element_index = inst->GetSingleWordInOperand(i);
48 if (c->AsNullConstant()) {
49 // Return Null for the return type.
50 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51 analysis::TypeManager* type_mgr = context->get_type_mgr();
52 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53 }
54
55 auto cc = c->AsCompositeConstant();
56 assert(cc != nullptr);
57 auto components = cc->GetComponents();
Ben Claytond0f684e2019-08-30 22:36:08 +010058 // Protect against invalid IR. Refuse to fold if the index is out
59 // of bounds.
60 if (element_index >= components.size()) return nullptr;
Chris Forbescc5697f2019-01-30 11:54:08 -080061 c = components[element_index];
62 }
63 return c;
64 };
65}
66
67ConstantFoldingRule FoldVectorShuffleWithConstants() {
68 return [](IRContext* context, Instruction* inst,
69 const std::vector<const analysis::Constant*>& constants)
70 -> const analysis::Constant* {
71 assert(inst->opcode() == SpvOpVectorShuffle);
72 const analysis::Constant* c1 = constants[0];
73 const analysis::Constant* c2 = constants[1];
74 if (c1 == nullptr || c2 == nullptr) {
75 return nullptr;
76 }
77
78 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
79 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
80
81 std::vector<const analysis::Constant*> c1_components;
82 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
83 c1_components = vec_const->GetComponents();
84 } else {
85 assert(c1->AsNullConstant());
86 const analysis::Constant* element =
87 const_mgr->GetConstant(element_type, {});
88 c1_components.resize(c1->type()->AsVector()->element_count(), element);
89 }
90 std::vector<const analysis::Constant*> c2_components;
91 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
92 c2_components = vec_const->GetComponents();
93 } else {
94 assert(c2->AsNullConstant());
95 const analysis::Constant* element =
96 const_mgr->GetConstant(element_type, {});
97 c2_components.resize(c2->type()->AsVector()->element_count(), element);
98 }
99
100 std::vector<uint32_t> ids;
101 const uint32_t undef_literal_value = 0xffffffff;
102 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
103 uint32_t index = inst->GetSingleWordInOperand(i);
104 if (index == undef_literal_value) {
105 // Don't fold shuffle with undef literal value.
106 return nullptr;
107 } else if (index < c1_components.size()) {
108 Instruction* member_inst =
109 const_mgr->GetDefiningInstruction(c1_components[index]);
110 ids.push_back(member_inst->result_id());
111 } else {
112 Instruction* member_inst = const_mgr->GetDefiningInstruction(
113 c2_components[index - c1_components.size()]);
114 ids.push_back(member_inst->result_id());
115 }
116 }
117
118 analysis::TypeManager* type_mgr = context->get_type_mgr();
119 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
120 };
121}
122
123ConstantFoldingRule FoldVectorTimesScalar() {
124 return [](IRContext* context, Instruction* inst,
125 const std::vector<const analysis::Constant*>& constants)
126 -> const analysis::Constant* {
127 assert(inst->opcode() == SpvOpVectorTimesScalar);
128 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
129 analysis::TypeManager* type_mgr = context->get_type_mgr();
130
131 if (!inst->IsFloatingPointFoldingAllowed()) {
132 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
133 return nullptr;
134 }
135 }
136
137 const analysis::Constant* c1 = constants[0];
138 const analysis::Constant* c2 = constants[1];
139
140 if (c1 && c1->IsZero()) {
141 return c1;
142 }
143
144 if (c2 && c2->IsZero()) {
145 // Get or create the NullConstant for this type.
146 std::vector<uint32_t> ids;
147 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
148 }
149
150 if (c1 == nullptr || c2 == nullptr) {
151 return nullptr;
152 }
153
154 // Check result type.
155 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
156 const analysis::Vector* vector_type = result_type->AsVector();
157 assert(vector_type != nullptr);
158 const analysis::Type* element_type = vector_type->element_type();
159 assert(element_type != nullptr);
160 const analysis::Float* float_type = element_type->AsFloat();
161 assert(float_type != nullptr);
162
163 // Check types of c1 and c2.
164 assert(c1->type()->AsVector() == vector_type);
165 assert(c1->type()->AsVector()->element_type() == element_type &&
166 c2->type() == element_type);
167
168 // Get a float vector that is the result of vector-times-scalar.
169 std::vector<const analysis::Constant*> c1_components =
170 c1->GetVectorComponents(const_mgr);
171 std::vector<uint32_t> ids;
172 if (float_type->width() == 32) {
173 float scalar = c2->GetFloat();
174 for (uint32_t i = 0; i < c1_components.size(); ++i) {
175 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
176 std::vector<uint32_t> words = result.GetWords();
177 const analysis::Constant* new_elem =
178 const_mgr->GetConstant(float_type, words);
179 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
180 }
181 return const_mgr->GetConstant(vector_type, ids);
182 } else if (float_type->width() == 64) {
183 double scalar = c2->GetDouble();
184 for (uint32_t i = 0; i < c1_components.size(); ++i) {
185 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
186 scalar);
187 std::vector<uint32_t> words = result.GetWords();
188 const analysis::Constant* new_elem =
189 const_mgr->GetConstant(float_type, words);
190 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
191 }
192 return const_mgr->GetConstant(vector_type, ids);
193 }
194 return nullptr;
195 };
196}
197
198ConstantFoldingRule FoldCompositeWithConstants() {
199 // Folds an OpCompositeConstruct where all of the inputs are constants to a
200 // constant. A new constant is created if necessary.
201 return [](IRContext* context, Instruction* inst,
202 const std::vector<const analysis::Constant*>& constants)
203 -> const analysis::Constant* {
204 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
205 analysis::TypeManager* type_mgr = context->get_type_mgr();
206 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
207 Instruction* type_inst =
208 context->get_def_use_mgr()->GetDef(inst->type_id());
209
210 std::vector<uint32_t> ids;
211 for (uint32_t i = 0; i < constants.size(); ++i) {
212 const analysis::Constant* element_const = constants[i];
213 if (element_const == nullptr) {
214 return nullptr;
215 }
216
217 uint32_t component_type_id = 0;
218 if (type_inst->opcode() == SpvOpTypeStruct) {
219 component_type_id = type_inst->GetSingleWordInOperand(i);
220 } else if (type_inst->opcode() == SpvOpTypeArray) {
221 component_type_id = type_inst->GetSingleWordInOperand(0);
222 }
223
224 uint32_t element_id =
225 const_mgr->FindDeclaredConstant(element_const, component_type_id);
226 if (element_id == 0) {
227 return nullptr;
228 }
229 ids.push_back(element_id);
230 }
231 return const_mgr->GetConstant(new_type, ids);
232 };
233}
234
235// The interface for a function that returns the result of applying a scalar
236// floating-point binary operation on |a| and |b|. The type of the return value
237// will be |type|. The input constants must also be of type |type|.
238using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
239 const analysis::Type* result_type, const analysis::Constant* a,
240 analysis::ConstantManager*)>;
241
242// The interface for a function that returns the result of applying a scalar
243// floating-point binary operation on |a| and |b|. The type of the return value
244// will be |type|. The input constants must also be of type |type|.
245using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
246 const analysis::Type* result_type, const analysis::Constant* a,
247 const analysis::Constant* b, analysis::ConstantManager*)>;
248
249// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
250// using |scalar_rule| and unary float point vectors ops by applying
251// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
252// that is returned assumes that |constants| contains 1 entry. If they are
253// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
254// whose element type is |Float| or |Integer|.
255ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
256 return [scalar_rule](IRContext* context, Instruction* inst,
257 const std::vector<const analysis::Constant*>& constants)
258 -> const analysis::Constant* {
259 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
260 analysis::TypeManager* type_mgr = context->get_type_mgr();
261 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
262 const analysis::Vector* vector_type = result_type->AsVector();
263
264 if (!inst->IsFloatingPointFoldingAllowed()) {
265 return nullptr;
266 }
267
268 if (constants[0] == nullptr) {
269 return nullptr;
270 }
271
272 if (vector_type != nullptr) {
273 std::vector<const analysis::Constant*> a_components;
274 std::vector<const analysis::Constant*> results_components;
275
276 a_components = constants[0]->GetVectorComponents(const_mgr);
277
278 // Fold each component of the vector.
279 for (uint32_t i = 0; i < a_components.size(); ++i) {
280 results_components.push_back(scalar_rule(vector_type->element_type(),
281 a_components[i], const_mgr));
282 if (results_components[i] == nullptr) {
283 return nullptr;
284 }
285 }
286
287 // Build the constant object and return it.
288 std::vector<uint32_t> ids;
289 for (const analysis::Constant* member : results_components) {
290 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
291 }
292 return const_mgr->GetConstant(vector_type, ids);
293 } else {
294 return scalar_rule(result_type, constants[0], const_mgr);
295 }
296 };
297}
298
299// Returns a |ConstantFoldingRule| that folds floating point scalars using
300// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
301// elements of the vector. The |ConstantFoldingRule| that is returned assumes
302// that |constants| contains 2 entries. If they are not |nullptr|, then their
303// type is either |Float| or a |Vector| whose element type is |Float|.
304ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
305 return [scalar_rule](IRContext* context, Instruction* inst,
306 const std::vector<const analysis::Constant*>& constants)
307 -> const analysis::Constant* {
308 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
309 analysis::TypeManager* type_mgr = context->get_type_mgr();
310 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
311 const analysis::Vector* vector_type = result_type->AsVector();
312
313 if (!inst->IsFloatingPointFoldingAllowed()) {
314 return nullptr;
315 }
316
317 if (constants[0] == nullptr || constants[1] == nullptr) {
318 return nullptr;
319 }
320
321 if (vector_type != nullptr) {
322 std::vector<const analysis::Constant*> a_components;
323 std::vector<const analysis::Constant*> b_components;
324 std::vector<const analysis::Constant*> results_components;
325
326 a_components = constants[0]->GetVectorComponents(const_mgr);
327 b_components = constants[1]->GetVectorComponents(const_mgr);
328
329 // Fold each component of the vector.
330 for (uint32_t i = 0; i < a_components.size(); ++i) {
331 results_components.push_back(scalar_rule(vector_type->element_type(),
332 a_components[i],
333 b_components[i], const_mgr));
334 if (results_components[i] == nullptr) {
335 return nullptr;
336 }
337 }
338
339 // Build the constant object and return it.
340 std::vector<uint32_t> ids;
341 for (const analysis::Constant* member : results_components) {
342 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
343 }
344 return const_mgr->GetConstant(vector_type, ids);
345 } else {
346 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
347 }
348 };
349}
350
351// This macro defines a |UnaryScalarFoldingRule| that performs float to
352// integer conversion.
353// TODO(greg-lunarg): Support for 64-bit integer types.
354UnaryScalarFoldingRule FoldFToIOp() {
355 return [](const analysis::Type* result_type, const analysis::Constant* a,
356 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
357 assert(result_type != nullptr && a != nullptr);
358 const analysis::Integer* integer_type = result_type->AsInteger();
359 const analysis::Float* float_type = a->type()->AsFloat();
360 assert(float_type != nullptr);
361 assert(integer_type != nullptr);
362 if (integer_type->width() != 32) return nullptr;
363 if (float_type->width() == 32) {
364 float fa = a->GetFloat();
365 uint32_t result = integer_type->IsSigned()
366 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
367 : static_cast<uint32_t>(fa);
368 std::vector<uint32_t> words = {result};
369 return const_mgr->GetConstant(result_type, words);
370 } else if (float_type->width() == 64) {
371 double fa = a->GetDouble();
372 uint32_t result = integer_type->IsSigned()
373 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
374 : static_cast<uint32_t>(fa);
375 std::vector<uint32_t> words = {result};
376 return const_mgr->GetConstant(result_type, words);
377 }
378 return nullptr;
379 };
380}
381
382// This function defines a |UnaryScalarFoldingRule| that performs integer to
383// float conversion.
384// TODO(greg-lunarg): Support for 64-bit integer types.
385UnaryScalarFoldingRule FoldIToFOp() {
386 return [](const analysis::Type* result_type, const analysis::Constant* a,
387 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
388 assert(result_type != nullptr && a != nullptr);
389 const analysis::Integer* integer_type = a->type()->AsInteger();
390 const analysis::Float* float_type = result_type->AsFloat();
391 assert(float_type != nullptr);
392 assert(integer_type != nullptr);
393 if (integer_type->width() != 32) return nullptr;
394 uint32_t ua = a->GetU32();
395 if (float_type->width() == 32) {
396 float result_val = integer_type->IsSigned()
397 ? static_cast<float>(static_cast<int32_t>(ua))
398 : static_cast<float>(ua);
399 utils::FloatProxy<float> result(result_val);
400 std::vector<uint32_t> words = {result.data()};
401 return const_mgr->GetConstant(result_type, words);
402 } else if (float_type->width() == 64) {
403 double result_val = integer_type->IsSigned()
404 ? static_cast<double>(static_cast<int32_t>(ua))
405 : static_cast<double>(ua);
406 utils::FloatProxy<double> result(result_val);
407 std::vector<uint32_t> words = result.GetWords();
408 return const_mgr->GetConstant(result_type, words);
409 }
410 return nullptr;
411 };
412}
413
Ben Claytonb73b7602019-07-29 13:56:13 +0100414// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
415UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
416 return [](const analysis::Type* result_type, const analysis::Constant* a,
417 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
418 assert(result_type != nullptr && a != nullptr);
419 const analysis::Float* float_type = a->type()->AsFloat();
420 assert(float_type != nullptr);
421 if (float_type->width() != 32) {
422 return nullptr;
423 }
424
425 float fa = a->GetFloat();
426 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
427 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
428 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
429 orignal.castTo(quantized, utils::round_direction::kToZero);
430 quantized.castTo(result, utils::round_direction::kToZero);
431 std::vector<uint32_t> words = {result.getBits()};
432 return const_mgr->GetConstant(result_type, words);
433 };
434}
435
// Expands to a |BinaryScalarFoldingRule| lambda that folds "a op b" for two
// scalar float constants. Both operands must have type |result_type|, and
// |op| must be a binary operator usable on both float and double. 32-bit
// values are computed in float and 64-bit values in double; for any other
// width the lambda returns nullptr.
#define FOLD_FPARITH_OP(op)                                                 \
  [](const analysis::Type* result_type, const analysis::Constant* a,        \
     const analysis::Constant* b,                                           \
     analysis::ConstantManager* const_mgr_in_macro)                         \
      -> const analysis::Constant* {                                        \
    assert(result_type != nullptr && a != nullptr && b != nullptr);         \
    assert(result_type == a->type() && result_type == b->type());           \
    const analysis::Float* fp_type_in_macro = result_type->AsFloat();       \
    assert(fp_type_in_macro != nullptr);                                    \
    std::vector<uint32_t> bits_in_macro;                                    \
    switch (fp_type_in_macro->width()) {                                    \
      case 32:                                                              \
        bits_in_macro =                                                     \
            utils::FloatProxy<float>(a->GetFloat() op b->GetFloat())        \
                .GetWords();                                                \
        break;                                                              \
      case 64:                                                              \
        bits_in_macro =                                                     \
            utils::FloatProxy<double>(a->GetDouble() op b->GetDouble())     \
                .GetWords();                                                \
        break;                                                              \
      default:                                                              \
        return nullptr;                                                     \
    }                                                                       \
    return const_mgr_in_macro->GetConstant(result_type, bits_in_macro);     \
  }
462
463// Define the folding rule for conversion between floating point and integer
464ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
465ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100466ConstantFoldingRule FoldQuantizeToF16() {
467 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
468}
Chris Forbescc5697f2019-01-30 11:54:08 -0800469
470// Define the folding rules for subtraction, addition, multiplication, and
471// division for floating point values.
472ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
473ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
474ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
475ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
476
// Combines the raw result of a floating-point comparison with NaN semantics.
// |op_result| is "operand 1 op operand 2" as evaluated in C++, and
// |op_unordered| is true when either operand is NaN.
// - Ordered comparisons (|need_ordered|) hold only when no operand is NaN and
//   |op_result| is true.
// - Unordered comparisons hold when either operand is NaN or |op_result| is
//   true.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
487
// Expands to a |BinaryScalarFoldingRule| lambda that folds the comparison
// "a op b" of two scalar float constants of the same type into a Bool
// constant of type |result_type|. |op| must be a comparison operator usable
// on both float and double. |ord| selects the NaN semantics: true for ordered
// (false if either operand is NaN), false for unordered (true if either
// operand is NaN); see CompareFloatingPoint. Widths other than 32 and 64 bits
// are not folded (the lambda returns nullptr).
#define FOLD_FPCMP_OP(op, ord)                                              \
  [](const analysis::Type* result_type, const analysis::Constant* a,        \
     const analysis::Constant* b,                                           \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {   \
    assert(result_type != nullptr && a != nullptr && b != nullptr);         \
    assert(result_type->AsBool());                                          \
    assert(a->type() == b->type());                                         \
    const analysis::Float* float_type = a->type()->AsFloat();               \
    assert(float_type != nullptr);                                          \
    bool cmp_holds = false;                                                 \
    bool any_nan = false;                                                   \
    if (float_type->width() == 32) {                                        \
      const float lhs = a->GetFloat();                                      \
      const float rhs = b->GetFloat();                                      \
      cmp_holds = lhs op rhs;                                               \
      any_nan = std::isnan(lhs) || std::isnan(rhs);                         \
    } else if (float_type->width() == 64) {                                 \
      const double lhs = a->GetDouble();                                    \
      const double rhs = b->GetDouble();                                    \
      cmp_holds = lhs op rhs;                                               \
      any_nan = std::isnan(lhs) || std::isnan(rhs);                         \
    } else {                                                                \
      return nullptr;                                                       \
    }                                                                       \
    const bool folded = CompareFloatingPoint(cmp_holds, any_nan, ord);      \
    const std::vector<uint32_t> bool_words = {static_cast<uint32_t>(folded)}; \
    return const_mgr->GetConstant(result_type, bool_words);                 \
  }
516
517// Define the folding rules for ordered and unordered comparison for floating
518// point values.
519ConstantFoldingRule FoldFOrdEqual() {
520 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
521}
522ConstantFoldingRule FoldFUnordEqual() {
523 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
524}
525ConstantFoldingRule FoldFOrdNotEqual() {
526 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
527}
528ConstantFoldingRule FoldFUnordNotEqual() {
529 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
530}
531ConstantFoldingRule FoldFOrdLessThan() {
532 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
533}
534ConstantFoldingRule FoldFUnordLessThan() {
535 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
536}
537ConstantFoldingRule FoldFOrdGreaterThan() {
538 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
539}
540ConstantFoldingRule FoldFUnordGreaterThan() {
541 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
542}
543ConstantFoldingRule FoldFOrdLessThanEqual() {
544 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
545}
546ConstantFoldingRule FoldFUnordLessThanEqual() {
547 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
548}
549ConstantFoldingRule FoldFOrdGreaterThanEqual() {
550 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
551}
552ConstantFoldingRule FoldFUnordGreaterThanEqual() {
553 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
554}
555
556// Folds an OpDot where all of the inputs are constants to a
557// constant. A new constant is created if necessary.
558ConstantFoldingRule FoldOpDotWithConstants() {
559 return [](IRContext* context, Instruction* inst,
560 const std::vector<const analysis::Constant*>& constants)
561 -> const analysis::Constant* {
562 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
563 analysis::TypeManager* type_mgr = context->get_type_mgr();
564 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
565 assert(new_type->AsFloat() && "OpDot should have a float return type.");
566 const analysis::Float* float_type = new_type->AsFloat();
567
568 if (!inst->IsFloatingPointFoldingAllowed()) {
569 return nullptr;
570 }
571
572 // If one of the operands is 0, then the result is 0.
573 bool has_zero_operand = false;
574
575 for (int i = 0; i < 2; ++i) {
576 if (constants[i]) {
577 if (constants[i]->AsNullConstant() ||
578 constants[i]->AsVectorConstant()->IsZero()) {
579 has_zero_operand = true;
580 break;
581 }
582 }
583 }
584
585 if (has_zero_operand) {
586 if (float_type->width() == 32) {
587 utils::FloatProxy<float> result(0.0f);
588 std::vector<uint32_t> words = result.GetWords();
589 return const_mgr->GetConstant(float_type, words);
590 }
591 if (float_type->width() == 64) {
592 utils::FloatProxy<double> result(0.0);
593 std::vector<uint32_t> words = result.GetWords();
594 return const_mgr->GetConstant(float_type, words);
595 }
596 return nullptr;
597 }
598
599 if (constants[0] == nullptr || constants[1] == nullptr) {
600 return nullptr;
601 }
602
603 std::vector<const analysis::Constant*> a_components;
604 std::vector<const analysis::Constant*> b_components;
605
606 a_components = constants[0]->GetVectorComponents(const_mgr);
607 b_components = constants[1]->GetVectorComponents(const_mgr);
608
609 utils::FloatProxy<double> result(0.0);
610 std::vector<uint32_t> words = result.GetWords();
611 const analysis::Constant* result_const =
612 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +0100613 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
614 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800615 if (a_components[i] == nullptr || b_components[i] == nullptr) {
616 return nullptr;
617 }
618
619 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
620 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +0100621 if (component == nullptr) {
622 return nullptr;
623 }
Chris Forbescc5697f2019-01-30 11:54:08 -0800624 result_const =
625 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
626 }
627 return result_const;
628 };
629}
630
631// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
632// from zero.
633UnaryScalarFoldingRule FoldFNegateOp() {
634 return [](const analysis::Type* result_type, const analysis::Constant* a,
635 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
636 assert(result_type != nullptr && a != nullptr);
637 assert(result_type == a->type());
638 const analysis::Float* float_type = result_type->AsFloat();
639 assert(float_type != nullptr);
640 if (float_type->width() == 32) {
641 float fa = a->GetFloat();
642 utils::FloatProxy<float> result(-fa);
643 std::vector<uint32_t> words = result.GetWords();
644 return const_mgr->GetConstant(result_type, words);
645 } else if (float_type->width() == 64) {
646 double da = a->GetDouble();
647 utils::FloatProxy<double> result(-da);
648 std::vector<uint32_t> words = result.GetWords();
649 return const_mgr->GetConstant(result_type, words);
650 }
651 return nullptr;
652 };
653}
654
655ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
656
657ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
658 return [cmp_opcode](IRContext* context, Instruction* inst,
659 const std::vector<const analysis::Constant*>& constants)
660 -> const analysis::Constant* {
661 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
662 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
663
664 if (!inst->IsFloatingPointFoldingAllowed()) {
665 return nullptr;
666 }
667
668 uint32_t non_const_idx = (constants[0] ? 1 : 0);
669 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
670 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
671
672 analysis::TypeManager* type_mgr = context->get_type_mgr();
673 const analysis::Type* operand_type =
674 type_mgr->GetType(operand_inst->type_id());
675
676 if (!operand_type->AsFloat()) {
677 return nullptr;
678 }
679
680 if (operand_type->AsFloat()->width() != 32 &&
681 operand_type->AsFloat()->width() != 64) {
682 return nullptr;
683 }
684
685 if (operand_inst->opcode() != SpvOpExtInst) {
686 return nullptr;
687 }
688
689 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
690 return nullptr;
691 }
692
693 if (constants[1] == nullptr && constants[0] == nullptr) {
694 return nullptr;
695 }
696
697 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
698 const analysis::Constant* max_const =
699 const_mgr->FindDeclaredConstant(max_id);
700
701 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
702 const analysis::Constant* min_const =
703 const_mgr->FindDeclaredConstant(min_id);
704
705 bool found_result = false;
706 bool result = false;
707
708 switch (cmp_opcode) {
709 case SpvOpFOrdLessThan:
710 case SpvOpFUnordLessThan:
711 case SpvOpFOrdGreaterThanEqual:
712 case SpvOpFUnordGreaterThanEqual:
713 if (constants[0]) {
714 if (min_const) {
715 if (constants[0]->GetValueAsDouble() <
716 min_const->GetValueAsDouble()) {
717 found_result = true;
718 result = (cmp_opcode == SpvOpFOrdLessThan ||
719 cmp_opcode == SpvOpFUnordLessThan);
720 }
721 }
722 if (max_const) {
723 if (constants[0]->GetValueAsDouble() >=
724 max_const->GetValueAsDouble()) {
725 found_result = true;
726 result = !(cmp_opcode == SpvOpFOrdLessThan ||
727 cmp_opcode == SpvOpFUnordLessThan);
728 }
729 }
730 }
731
732 if (constants[1]) {
733 if (max_const) {
734 if (max_const->GetValueAsDouble() <
735 constants[1]->GetValueAsDouble()) {
736 found_result = true;
737 result = (cmp_opcode == SpvOpFOrdLessThan ||
738 cmp_opcode == SpvOpFUnordLessThan);
739 }
740 }
741
742 if (min_const) {
743 if (min_const->GetValueAsDouble() >=
744 constants[1]->GetValueAsDouble()) {
745 found_result = true;
746 result = !(cmp_opcode == SpvOpFOrdLessThan ||
747 cmp_opcode == SpvOpFUnordLessThan);
748 }
749 }
750 }
751 break;
752 case SpvOpFOrdGreaterThan:
753 case SpvOpFUnordGreaterThan:
754 case SpvOpFOrdLessThanEqual:
755 case SpvOpFUnordLessThanEqual:
756 if (constants[0]) {
757 if (min_const) {
758 if (constants[0]->GetValueAsDouble() <=
759 min_const->GetValueAsDouble()) {
760 found_result = true;
761 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
762 cmp_opcode == SpvOpFUnordLessThanEqual);
763 }
764 }
765 if (max_const) {
766 if (constants[0]->GetValueAsDouble() >
767 max_const->GetValueAsDouble()) {
768 found_result = true;
769 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
770 cmp_opcode == SpvOpFUnordLessThanEqual);
771 }
772 }
773 }
774
775 if (constants[1]) {
776 if (max_const) {
777 if (max_const->GetValueAsDouble() <=
778 constants[1]->GetValueAsDouble()) {
779 found_result = true;
780 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
781 cmp_opcode == SpvOpFUnordLessThanEqual);
782 }
783 }
784
785 if (min_const) {
786 if (min_const->GetValueAsDouble() >
787 constants[1]->GetValueAsDouble()) {
788 found_result = true;
789 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
790 cmp_opcode == SpvOpFUnordLessThanEqual);
791 }
792 }
793 }
794 break;
795 default:
796 return nullptr;
797 }
798
799 if (!found_result) {
800 return nullptr;
801 }
802
803 const analysis::Type* bool_type =
804 context->get_type_mgr()->GetType(inst->type_id());
805 const analysis::Constant* result_const =
806 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
807 assert(result_const);
808 return result_const;
809 };
810}
811
Ben Claytond0f684e2019-08-30 22:36:08 +0100812ConstantFoldingRule FoldFMix() {
813 return [](IRContext* context, Instruction* inst,
814 const std::vector<const analysis::Constant*>& constants)
815 -> const analysis::Constant* {
816 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
817 assert(inst->opcode() == SpvOpExtInst &&
818 "Expecting an extended instruction.");
819 assert(inst->GetSingleWordInOperand(0) ==
820 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
821 "Expecting a GLSLstd450 extended instruction.");
822 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
823 "Expecting and FMix instruction.");
824
825 if (!inst->IsFloatingPointFoldingAllowed()) {
826 return nullptr;
827 }
828
829 // Make sure all FMix operands are constants.
830 for (uint32_t i = 1; i < 4; i++) {
831 if (constants[i] == nullptr) {
832 return nullptr;
833 }
834 }
835
836 const analysis::Constant* one;
837 if (constants[1]->type()->AsFloat()->width() == 32) {
838 one = const_mgr->GetConstant(constants[1]->type(),
839 utils::FloatProxy<float>(1.0f).GetWords());
840 } else {
841 one = const_mgr->GetConstant(constants[1]->type(),
842 utils::FloatProxy<double>(1.0).GetWords());
843 }
844
845 const analysis::Constant* temp1 =
846 FOLD_FPARITH_OP(-)(constants[1]->type(), one, constants[3], const_mgr);
847 if (temp1 == nullptr) {
848 return nullptr;
849 }
850
851 const analysis::Constant* temp2 = FOLD_FPARITH_OP(*)(
852 constants[1]->type(), constants[1], temp1, const_mgr);
853 if (temp2 == nullptr) {
854 return nullptr;
855 }
856 const analysis::Constant* temp3 = FOLD_FPARITH_OP(*)(
857 constants[2]->type(), constants[2], constants[3], const_mgr);
858 if (temp3 == nullptr) {
859 return nullptr;
860 }
861 return FOLD_FPARITH_OP(+)(temp2->type(), temp2, temp3, const_mgr);
862 };
863}
864
Chris Forbescc5697f2019-01-30 11:54:08 -0800865} // namespace
866
Ben Claytond0f684e2019-08-30 22:36:08 +0100867void ConstantFoldingRules::AddFoldingRules() {
Chris Forbescc5697f2019-01-30 11:54:08 -0800868 // Add all folding rules to the list for the opcodes to which they apply.
869 // Note that the order in which rules are added to the list matters. If a rule
870 // applies to the instruction, the rest of the rules will not be attempted.
871 // Take that into consideration.
872
873 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
874
875 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
876
877 rules_[SpvOpConvertFToS].push_back(FoldFToI());
878 rules_[SpvOpConvertFToU].push_back(FoldFToI());
879 rules_[SpvOpConvertSToF].push_back(FoldIToF());
880 rules_[SpvOpConvertUToF].push_back(FoldIToF());
881
882 rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
883 rules_[SpvOpFAdd].push_back(FoldFAdd());
884 rules_[SpvOpFDiv].push_back(FoldFDiv());
885 rules_[SpvOpFMul].push_back(FoldFMul());
886 rules_[SpvOpFSub].push_back(FoldFSub());
887
888 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
889
890 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
891
892 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
893
894 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
895
896 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
897 rules_[SpvOpFOrdLessThan].push_back(
898 FoldFClampFeedingCompare(SpvOpFOrdLessThan));
899
900 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
901 rules_[SpvOpFUnordLessThan].push_back(
902 FoldFClampFeedingCompare(SpvOpFUnordLessThan));
903
904 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
905 rules_[SpvOpFOrdGreaterThan].push_back(
906 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
907
908 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
909 rules_[SpvOpFUnordGreaterThan].push_back(
910 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
911
912 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
913 rules_[SpvOpFOrdLessThanEqual].push_back(
914 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
915
916 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
917 rules_[SpvOpFUnordLessThanEqual].push_back(
918 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
919
920 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
921 rules_[SpvOpFOrdGreaterThanEqual].push_back(
922 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
923
924 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
925 rules_[SpvOpFUnordGreaterThanEqual].push_back(
926 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
927
928 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
929 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
930
931 rules_[SpvOpFNegate].push_back(FoldFNegate());
Ben Claytonb73b7602019-07-29 13:56:13 +0100932 rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
Ben Claytond0f684e2019-08-30 22:36:08 +0100933
934 // Add rules for GLSLstd450
935 FeatureManager* feature_manager = context_->get_feature_mgr();
936 uint32_t ext_inst_glslstd450_id =
937 feature_manager->GetExtInstImportId_GLSLstd450();
938 if (ext_inst_glslstd450_id != 0) {
939 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
940 }
Chris Forbescc5697f2019-01-30 11:54:08 -0800941}
942} // namespace opt
943} // namespace spvtools