blob: 10fcde40803d1e85cc00d0fce1924e2eabe782ab [file] [log] [blame]
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
22
// Position of the composite operand within the |constants| array that is
// handed to the OpCompositeExtract folding rule.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24
25// Returns true if |type| is Float or a vector of Float.
26bool HasFloatingPoint(const analysis::Type* type) {
27 if (type->AsFloat()) {
28 return true;
29 } else if (const analysis::Vector* vec_type = type->AsVector()) {
30 return vec_type->element_type()->AsFloat() != nullptr;
31 }
32
33 return false;
34}
35
36// Folds an OpcompositeExtract where input is a composite constant.
37ConstantFoldingRule FoldExtractWithConstants() {
38 return [](IRContext* context, Instruction* inst,
39 const std::vector<const analysis::Constant*>& constants)
40 -> const analysis::Constant* {
41 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
42 if (c == nullptr) {
43 return nullptr;
44 }
45
46 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
47 uint32_t element_index = inst->GetSingleWordInOperand(i);
48 if (c->AsNullConstant()) {
49 // Return Null for the return type.
50 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
51 analysis::TypeManager* type_mgr = context->get_type_mgr();
52 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
53 }
54
55 auto cc = c->AsCompositeConstant();
56 assert(cc != nullptr);
57 auto components = cc->GetComponents();
58 c = components[element_index];
59 }
60 return c;
61 };
62}
63
64ConstantFoldingRule FoldVectorShuffleWithConstants() {
65 return [](IRContext* context, Instruction* inst,
66 const std::vector<const analysis::Constant*>& constants)
67 -> const analysis::Constant* {
68 assert(inst->opcode() == SpvOpVectorShuffle);
69 const analysis::Constant* c1 = constants[0];
70 const analysis::Constant* c2 = constants[1];
71 if (c1 == nullptr || c2 == nullptr) {
72 return nullptr;
73 }
74
75 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
76 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
77
78 std::vector<const analysis::Constant*> c1_components;
79 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
80 c1_components = vec_const->GetComponents();
81 } else {
82 assert(c1->AsNullConstant());
83 const analysis::Constant* element =
84 const_mgr->GetConstant(element_type, {});
85 c1_components.resize(c1->type()->AsVector()->element_count(), element);
86 }
87 std::vector<const analysis::Constant*> c2_components;
88 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
89 c2_components = vec_const->GetComponents();
90 } else {
91 assert(c2->AsNullConstant());
92 const analysis::Constant* element =
93 const_mgr->GetConstant(element_type, {});
94 c2_components.resize(c2->type()->AsVector()->element_count(), element);
95 }
96
97 std::vector<uint32_t> ids;
98 const uint32_t undef_literal_value = 0xffffffff;
99 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
100 uint32_t index = inst->GetSingleWordInOperand(i);
101 if (index == undef_literal_value) {
102 // Don't fold shuffle with undef literal value.
103 return nullptr;
104 } else if (index < c1_components.size()) {
105 Instruction* member_inst =
106 const_mgr->GetDefiningInstruction(c1_components[index]);
107 ids.push_back(member_inst->result_id());
108 } else {
109 Instruction* member_inst = const_mgr->GetDefiningInstruction(
110 c2_components[index - c1_components.size()]);
111 ids.push_back(member_inst->result_id());
112 }
113 }
114
115 analysis::TypeManager* type_mgr = context->get_type_mgr();
116 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
117 };
118}
119
120ConstantFoldingRule FoldVectorTimesScalar() {
121 return [](IRContext* context, Instruction* inst,
122 const std::vector<const analysis::Constant*>& constants)
123 -> const analysis::Constant* {
124 assert(inst->opcode() == SpvOpVectorTimesScalar);
125 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
126 analysis::TypeManager* type_mgr = context->get_type_mgr();
127
128 if (!inst->IsFloatingPointFoldingAllowed()) {
129 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
130 return nullptr;
131 }
132 }
133
134 const analysis::Constant* c1 = constants[0];
135 const analysis::Constant* c2 = constants[1];
136
137 if (c1 && c1->IsZero()) {
138 return c1;
139 }
140
141 if (c2 && c2->IsZero()) {
142 // Get or create the NullConstant for this type.
143 std::vector<uint32_t> ids;
144 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
145 }
146
147 if (c1 == nullptr || c2 == nullptr) {
148 return nullptr;
149 }
150
151 // Check result type.
152 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
153 const analysis::Vector* vector_type = result_type->AsVector();
154 assert(vector_type != nullptr);
155 const analysis::Type* element_type = vector_type->element_type();
156 assert(element_type != nullptr);
157 const analysis::Float* float_type = element_type->AsFloat();
158 assert(float_type != nullptr);
159
160 // Check types of c1 and c2.
161 assert(c1->type()->AsVector() == vector_type);
162 assert(c1->type()->AsVector()->element_type() == element_type &&
163 c2->type() == element_type);
164
165 // Get a float vector that is the result of vector-times-scalar.
166 std::vector<const analysis::Constant*> c1_components =
167 c1->GetVectorComponents(const_mgr);
168 std::vector<uint32_t> ids;
169 if (float_type->width() == 32) {
170 float scalar = c2->GetFloat();
171 for (uint32_t i = 0; i < c1_components.size(); ++i) {
172 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
173 std::vector<uint32_t> words = result.GetWords();
174 const analysis::Constant* new_elem =
175 const_mgr->GetConstant(float_type, words);
176 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
177 }
178 return const_mgr->GetConstant(vector_type, ids);
179 } else if (float_type->width() == 64) {
180 double scalar = c2->GetDouble();
181 for (uint32_t i = 0; i < c1_components.size(); ++i) {
182 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
183 scalar);
184 std::vector<uint32_t> words = result.GetWords();
185 const analysis::Constant* new_elem =
186 const_mgr->GetConstant(float_type, words);
187 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
188 }
189 return const_mgr->GetConstant(vector_type, ids);
190 }
191 return nullptr;
192 };
193}
194
195ConstantFoldingRule FoldCompositeWithConstants() {
196 // Folds an OpCompositeConstruct where all of the inputs are constants to a
197 // constant. A new constant is created if necessary.
198 return [](IRContext* context, Instruction* inst,
199 const std::vector<const analysis::Constant*>& constants)
200 -> const analysis::Constant* {
201 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
202 analysis::TypeManager* type_mgr = context->get_type_mgr();
203 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
204 Instruction* type_inst =
205 context->get_def_use_mgr()->GetDef(inst->type_id());
206
207 std::vector<uint32_t> ids;
208 for (uint32_t i = 0; i < constants.size(); ++i) {
209 const analysis::Constant* element_const = constants[i];
210 if (element_const == nullptr) {
211 return nullptr;
212 }
213
214 uint32_t component_type_id = 0;
215 if (type_inst->opcode() == SpvOpTypeStruct) {
216 component_type_id = type_inst->GetSingleWordInOperand(i);
217 } else if (type_inst->opcode() == SpvOpTypeArray) {
218 component_type_id = type_inst->GetSingleWordInOperand(0);
219 }
220
221 uint32_t element_id =
222 const_mgr->FindDeclaredConstant(element_const, component_type_id);
223 if (element_id == 0) {
224 return nullptr;
225 }
226 ids.push_back(element_id);
227 }
228 return const_mgr->GetConstant(new_type, ids);
229 };
230}
231
232// The interface for a function that returns the result of applying a scalar
233// floating-point binary operation on |a| and |b|. The type of the return value
234// will be |type|. The input constants must also be of type |type|.
235using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
236 const analysis::Type* result_type, const analysis::Constant* a,
237 analysis::ConstantManager*)>;
238
239// The interface for a function that returns the result of applying a scalar
240// floating-point binary operation on |a| and |b|. The type of the return value
241// will be |type|. The input constants must also be of type |type|.
242using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
243 const analysis::Type* result_type, const analysis::Constant* a,
244 const analysis::Constant* b, analysis::ConstantManager*)>;
245
246// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
247// using |scalar_rule| and unary float point vectors ops by applying
248// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
249// that is returned assumes that |constants| contains 1 entry. If they are
250// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
251// whose element type is |Float| or |Integer|.
252ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
253 return [scalar_rule](IRContext* context, Instruction* inst,
254 const std::vector<const analysis::Constant*>& constants)
255 -> const analysis::Constant* {
256 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
257 analysis::TypeManager* type_mgr = context->get_type_mgr();
258 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
259 const analysis::Vector* vector_type = result_type->AsVector();
260
261 if (!inst->IsFloatingPointFoldingAllowed()) {
262 return nullptr;
263 }
264
265 if (constants[0] == nullptr) {
266 return nullptr;
267 }
268
269 if (vector_type != nullptr) {
270 std::vector<const analysis::Constant*> a_components;
271 std::vector<const analysis::Constant*> results_components;
272
273 a_components = constants[0]->GetVectorComponents(const_mgr);
274
275 // Fold each component of the vector.
276 for (uint32_t i = 0; i < a_components.size(); ++i) {
277 results_components.push_back(scalar_rule(vector_type->element_type(),
278 a_components[i], const_mgr));
279 if (results_components[i] == nullptr) {
280 return nullptr;
281 }
282 }
283
284 // Build the constant object and return it.
285 std::vector<uint32_t> ids;
286 for (const analysis::Constant* member : results_components) {
287 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
288 }
289 return const_mgr->GetConstant(vector_type, ids);
290 } else {
291 return scalar_rule(result_type, constants[0], const_mgr);
292 }
293 };
294}
295
296// Returns a |ConstantFoldingRule| that folds floating point scalars using
297// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
298// elements of the vector. The |ConstantFoldingRule| that is returned assumes
299// that |constants| contains 2 entries. If they are not |nullptr|, then their
300// type is either |Float| or a |Vector| whose element type is |Float|.
301ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
302 return [scalar_rule](IRContext* context, Instruction* inst,
303 const std::vector<const analysis::Constant*>& constants)
304 -> const analysis::Constant* {
305 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
306 analysis::TypeManager* type_mgr = context->get_type_mgr();
307 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
308 const analysis::Vector* vector_type = result_type->AsVector();
309
310 if (!inst->IsFloatingPointFoldingAllowed()) {
311 return nullptr;
312 }
313
314 if (constants[0] == nullptr || constants[1] == nullptr) {
315 return nullptr;
316 }
317
318 if (vector_type != nullptr) {
319 std::vector<const analysis::Constant*> a_components;
320 std::vector<const analysis::Constant*> b_components;
321 std::vector<const analysis::Constant*> results_components;
322
323 a_components = constants[0]->GetVectorComponents(const_mgr);
324 b_components = constants[1]->GetVectorComponents(const_mgr);
325
326 // Fold each component of the vector.
327 for (uint32_t i = 0; i < a_components.size(); ++i) {
328 results_components.push_back(scalar_rule(vector_type->element_type(),
329 a_components[i],
330 b_components[i], const_mgr));
331 if (results_components[i] == nullptr) {
332 return nullptr;
333 }
334 }
335
336 // Build the constant object and return it.
337 std::vector<uint32_t> ids;
338 for (const analysis::Constant* member : results_components) {
339 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
340 }
341 return const_mgr->GetConstant(vector_type, ids);
342 } else {
343 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
344 }
345 };
346}
347
348// This macro defines a |UnaryScalarFoldingRule| that performs float to
349// integer conversion.
350// TODO(greg-lunarg): Support for 64-bit integer types.
351UnaryScalarFoldingRule FoldFToIOp() {
352 return [](const analysis::Type* result_type, const analysis::Constant* a,
353 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
354 assert(result_type != nullptr && a != nullptr);
355 const analysis::Integer* integer_type = result_type->AsInteger();
356 const analysis::Float* float_type = a->type()->AsFloat();
357 assert(float_type != nullptr);
358 assert(integer_type != nullptr);
359 if (integer_type->width() != 32) return nullptr;
360 if (float_type->width() == 32) {
361 float fa = a->GetFloat();
362 uint32_t result = integer_type->IsSigned()
363 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
364 : static_cast<uint32_t>(fa);
365 std::vector<uint32_t> words = {result};
366 return const_mgr->GetConstant(result_type, words);
367 } else if (float_type->width() == 64) {
368 double fa = a->GetDouble();
369 uint32_t result = integer_type->IsSigned()
370 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
371 : static_cast<uint32_t>(fa);
372 std::vector<uint32_t> words = {result};
373 return const_mgr->GetConstant(result_type, words);
374 }
375 return nullptr;
376 };
377}
378
379// This function defines a |UnaryScalarFoldingRule| that performs integer to
380// float conversion.
381// TODO(greg-lunarg): Support for 64-bit integer types.
382UnaryScalarFoldingRule FoldIToFOp() {
383 return [](const analysis::Type* result_type, const analysis::Constant* a,
384 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
385 assert(result_type != nullptr && a != nullptr);
386 const analysis::Integer* integer_type = a->type()->AsInteger();
387 const analysis::Float* float_type = result_type->AsFloat();
388 assert(float_type != nullptr);
389 assert(integer_type != nullptr);
390 if (integer_type->width() != 32) return nullptr;
391 uint32_t ua = a->GetU32();
392 if (float_type->width() == 32) {
393 float result_val = integer_type->IsSigned()
394 ? static_cast<float>(static_cast<int32_t>(ua))
395 : static_cast<float>(ua);
396 utils::FloatProxy<float> result(result_val);
397 std::vector<uint32_t> words = {result.data()};
398 return const_mgr->GetConstant(result_type, words);
399 } else if (float_type->width() == 64) {
400 double result_val = integer_type->IsSigned()
401 ? static_cast<double>(static_cast<int32_t>(ua))
402 : static_cast<double>(ua);
403 utils::FloatProxy<double> result(result_val);
404 std::vector<uint32_t> words = result.GetWords();
405 return const_mgr->GetConstant(result_type, words);
406 }
407 return nullptr;
408 };
409}
410
Ben Claytonb73b7602019-07-29 13:56:13 +0100411// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
412UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
413 return [](const analysis::Type* result_type, const analysis::Constant* a,
414 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
415 assert(result_type != nullptr && a != nullptr);
416 const analysis::Float* float_type = a->type()->AsFloat();
417 assert(float_type != nullptr);
418 if (float_type->width() != 32) {
419 return nullptr;
420 }
421
422 float fa = a->GetFloat();
423 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
424 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
425 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
426 orignal.castTo(quantized, utils::round_direction::kToZero);
427 quantized.castTo(result, utils::round_direction::kToZero);
428 std::vector<uint32_t> words = {result.getBits()};
429 return const_mgr->GetConstant(result_type, words);
430 };
431}
432
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// applies |op| to two float constants of the same type as the result.  The
// operator |op| must work for both float and double, and use syntax
// "f1 op f2".  The lambda returns nullptr for float widths other than 32 and
// 64.  Local names carry an "_in_macro" suffix so they stand apart from names
// at the expansion site.
#define FOLD_FPARITH_OP(op)                                                  \
  [](const analysis::Type* result_type, const analysis::Constant* a,        \
     const analysis::Constant* b,                                            \
     analysis::ConstantManager* const_mgr_in_macro)                          \
      -> const analysis::Constant* {                                         \
    assert(result_type != nullptr && a != nullptr && b != nullptr);          \
    assert(result_type == a->type() && result_type == b->type());            \
    const analysis::Float* float_type_in_macro = result_type->AsFloat();     \
    assert(float_type_in_macro != nullptr);                                  \
    switch (float_type_in_macro->width()) {                                  \
      case 32: {                                                             \
        float lhs_in_macro = a->GetFloat();                                  \
        float rhs_in_macro = b->GetFloat();                                  \
        utils::FloatProxy<float> folded_in_macro(lhs_in_macro op             \
                                                     rhs_in_macro);          \
        std::vector<uint32_t> words_in_macro = folded_in_macro.GetWords();   \
        return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
      }                                                                      \
      case 64: {                                                             \
        double lhs_in_macro = a->GetDouble();                                \
        double rhs_in_macro = b->GetDouble();                                \
        utils::FloatProxy<double> folded_in_macro(lhs_in_macro op            \
                                                      rhs_in_macro);         \
        std::vector<uint32_t> words_in_macro = folded_in_macro.GetWords();   \
        return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
      }                                                                      \
      default:                                                               \
        return nullptr;                                                      \
    }                                                                        \
  }
459
460// Define the folding rule for conversion between floating point and integer
461ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
462ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100463ConstantFoldingRule FoldQuantizeToF16() {
464 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
465}
Chris Forbescc5697f2019-01-30 11:54:08 -0800466
467// Define the folding rules for subtraction, addition, multiplication, and
468// division for floating point values.
469ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
470ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
471ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
472ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
473
// Combines the raw outcome of a comparison with the "unordered" flag of its
// operands.
//   |op_result|    - outcome of applying the comparison operator directly.
//   |op_unordered| - whether the operands are unordered (the comparison macro
//                    in this file passes "either operand is NaN").
//   |need_ordered| - true for an ordered comparison: the operands must be
//                    ordered AND the comparison must hold; false for an
//                    unordered comparison: the operands are unordered OR the
//                    comparison holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
484
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// compares two float constants of the same type with |op| and produces a Bool
// constant.  The operator |op| must work for both float and double, and use
// syntax "f1 op f2".  |ord| selects ordered (true) or unordered (false)
// handling of NaN operands via CompareFloatingPoint.  The lambda returns
// nullptr for float widths other than 32 and 64.
#define FOLD_FPCMP_OP(op, ord)                                             \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                          \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);        \
    assert(result_type->AsBool());                                         \
    assert(a->type() == b->type());                                        \
    const analysis::Float* float_type = a->type()->AsFloat();              \
    assert(float_type != nullptr);                                         \
    if (float_type->width() == 32) {                                       \
      const float lhs = a->GetFloat();                                     \
      const float rhs = b->GetFloat();                                     \
      const bool unordered = std::isnan(lhs) || std::isnan(rhs);           \
      const bool verdict = CompareFloatingPoint(lhs op rhs, unordered,     \
                                                ord);                      \
      std::vector<uint32_t> words = {uint32_t(verdict)};                   \
      return const_mgr->GetConstant(result_type, words);                   \
    }                                                                      \
    if (float_type->width() == 64) {                                       \
      const double lhs = a->GetDouble();                                   \
      const double rhs = b->GetDouble();                                   \
      const bool unordered = std::isnan(lhs) || std::isnan(rhs);           \
      const bool verdict = CompareFloatingPoint(lhs op rhs, unordered,     \
                                                ord);                      \
      std::vector<uint32_t> words = {uint32_t(verdict)};                   \
      return const_mgr->GetConstant(result_type, words);                   \
    }                                                                      \
    return nullptr;                                                        \
  }
513
514// Define the folding rules for ordered and unordered comparison for floating
515// point values.
516ConstantFoldingRule FoldFOrdEqual() {
517 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
518}
519ConstantFoldingRule FoldFUnordEqual() {
520 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
521}
522ConstantFoldingRule FoldFOrdNotEqual() {
523 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
524}
525ConstantFoldingRule FoldFUnordNotEqual() {
526 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
527}
528ConstantFoldingRule FoldFOrdLessThan() {
529 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
530}
531ConstantFoldingRule FoldFUnordLessThan() {
532 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
533}
534ConstantFoldingRule FoldFOrdGreaterThan() {
535 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
536}
537ConstantFoldingRule FoldFUnordGreaterThan() {
538 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
539}
540ConstantFoldingRule FoldFOrdLessThanEqual() {
541 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
542}
543ConstantFoldingRule FoldFUnordLessThanEqual() {
544 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
545}
546ConstantFoldingRule FoldFOrdGreaterThanEqual() {
547 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
548}
549ConstantFoldingRule FoldFUnordGreaterThanEqual() {
550 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
551}
552
553// Folds an OpDot where all of the inputs are constants to a
554// constant. A new constant is created if necessary.
555ConstantFoldingRule FoldOpDotWithConstants() {
556 return [](IRContext* context, Instruction* inst,
557 const std::vector<const analysis::Constant*>& constants)
558 -> const analysis::Constant* {
559 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
560 analysis::TypeManager* type_mgr = context->get_type_mgr();
561 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
562 assert(new_type->AsFloat() && "OpDot should have a float return type.");
563 const analysis::Float* float_type = new_type->AsFloat();
564
565 if (!inst->IsFloatingPointFoldingAllowed()) {
566 return nullptr;
567 }
568
569 // If one of the operands is 0, then the result is 0.
570 bool has_zero_operand = false;
571
572 for (int i = 0; i < 2; ++i) {
573 if (constants[i]) {
574 if (constants[i]->AsNullConstant() ||
575 constants[i]->AsVectorConstant()->IsZero()) {
576 has_zero_operand = true;
577 break;
578 }
579 }
580 }
581
582 if (has_zero_operand) {
583 if (float_type->width() == 32) {
584 utils::FloatProxy<float> result(0.0f);
585 std::vector<uint32_t> words = result.GetWords();
586 return const_mgr->GetConstant(float_type, words);
587 }
588 if (float_type->width() == 64) {
589 utils::FloatProxy<double> result(0.0);
590 std::vector<uint32_t> words = result.GetWords();
591 return const_mgr->GetConstant(float_type, words);
592 }
593 return nullptr;
594 }
595
596 if (constants[0] == nullptr || constants[1] == nullptr) {
597 return nullptr;
598 }
599
600 std::vector<const analysis::Constant*> a_components;
601 std::vector<const analysis::Constant*> b_components;
602
603 a_components = constants[0]->GetVectorComponents(const_mgr);
604 b_components = constants[1]->GetVectorComponents(const_mgr);
605
606 utils::FloatProxy<double> result(0.0);
607 std::vector<uint32_t> words = result.GetWords();
608 const analysis::Constant* result_const =
609 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +0100610 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
611 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800612 if (a_components[i] == nullptr || b_components[i] == nullptr) {
613 return nullptr;
614 }
615
616 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
617 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +0100618 if (component == nullptr) {
619 return nullptr;
620 }
Chris Forbescc5697f2019-01-30 11:54:08 -0800621 result_const =
622 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
623 }
624 return result_const;
625 };
626}
627
628// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
629// from zero.
630UnaryScalarFoldingRule FoldFNegateOp() {
631 return [](const analysis::Type* result_type, const analysis::Constant* a,
632 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
633 assert(result_type != nullptr && a != nullptr);
634 assert(result_type == a->type());
635 const analysis::Float* float_type = result_type->AsFloat();
636 assert(float_type != nullptr);
637 if (float_type->width() == 32) {
638 float fa = a->GetFloat();
639 utils::FloatProxy<float> result(-fa);
640 std::vector<uint32_t> words = result.GetWords();
641 return const_mgr->GetConstant(result_type, words);
642 } else if (float_type->width() == 64) {
643 double da = a->GetDouble();
644 utils::FloatProxy<double> result(-da);
645 std::vector<uint32_t> words = result.GetWords();
646 return const_mgr->GetConstant(result_type, words);
647 }
648 return nullptr;
649 };
650}
651
652ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
653
654ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
655 return [cmp_opcode](IRContext* context, Instruction* inst,
656 const std::vector<const analysis::Constant*>& constants)
657 -> const analysis::Constant* {
658 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
659 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
660
661 if (!inst->IsFloatingPointFoldingAllowed()) {
662 return nullptr;
663 }
664
665 uint32_t non_const_idx = (constants[0] ? 1 : 0);
666 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
667 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
668
669 analysis::TypeManager* type_mgr = context->get_type_mgr();
670 const analysis::Type* operand_type =
671 type_mgr->GetType(operand_inst->type_id());
672
673 if (!operand_type->AsFloat()) {
674 return nullptr;
675 }
676
677 if (operand_type->AsFloat()->width() != 32 &&
678 operand_type->AsFloat()->width() != 64) {
679 return nullptr;
680 }
681
682 if (operand_inst->opcode() != SpvOpExtInst) {
683 return nullptr;
684 }
685
686 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
687 return nullptr;
688 }
689
690 if (constants[1] == nullptr && constants[0] == nullptr) {
691 return nullptr;
692 }
693
694 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
695 const analysis::Constant* max_const =
696 const_mgr->FindDeclaredConstant(max_id);
697
698 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
699 const analysis::Constant* min_const =
700 const_mgr->FindDeclaredConstant(min_id);
701
702 bool found_result = false;
703 bool result = false;
704
705 switch (cmp_opcode) {
706 case SpvOpFOrdLessThan:
707 case SpvOpFUnordLessThan:
708 case SpvOpFOrdGreaterThanEqual:
709 case SpvOpFUnordGreaterThanEqual:
710 if (constants[0]) {
711 if (min_const) {
712 if (constants[0]->GetValueAsDouble() <
713 min_const->GetValueAsDouble()) {
714 found_result = true;
715 result = (cmp_opcode == SpvOpFOrdLessThan ||
716 cmp_opcode == SpvOpFUnordLessThan);
717 }
718 }
719 if (max_const) {
720 if (constants[0]->GetValueAsDouble() >=
721 max_const->GetValueAsDouble()) {
722 found_result = true;
723 result = !(cmp_opcode == SpvOpFOrdLessThan ||
724 cmp_opcode == SpvOpFUnordLessThan);
725 }
726 }
727 }
728
729 if (constants[1]) {
730 if (max_const) {
731 if (max_const->GetValueAsDouble() <
732 constants[1]->GetValueAsDouble()) {
733 found_result = true;
734 result = (cmp_opcode == SpvOpFOrdLessThan ||
735 cmp_opcode == SpvOpFUnordLessThan);
736 }
737 }
738
739 if (min_const) {
740 if (min_const->GetValueAsDouble() >=
741 constants[1]->GetValueAsDouble()) {
742 found_result = true;
743 result = !(cmp_opcode == SpvOpFOrdLessThan ||
744 cmp_opcode == SpvOpFUnordLessThan);
745 }
746 }
747 }
748 break;
749 case SpvOpFOrdGreaterThan:
750 case SpvOpFUnordGreaterThan:
751 case SpvOpFOrdLessThanEqual:
752 case SpvOpFUnordLessThanEqual:
753 if (constants[0]) {
754 if (min_const) {
755 if (constants[0]->GetValueAsDouble() <=
756 min_const->GetValueAsDouble()) {
757 found_result = true;
758 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
759 cmp_opcode == SpvOpFUnordLessThanEqual);
760 }
761 }
762 if (max_const) {
763 if (constants[0]->GetValueAsDouble() >
764 max_const->GetValueAsDouble()) {
765 found_result = true;
766 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
767 cmp_opcode == SpvOpFUnordLessThanEqual);
768 }
769 }
770 }
771
772 if (constants[1]) {
773 if (max_const) {
774 if (max_const->GetValueAsDouble() <=
775 constants[1]->GetValueAsDouble()) {
776 found_result = true;
777 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
778 cmp_opcode == SpvOpFUnordLessThanEqual);
779 }
780 }
781
782 if (min_const) {
783 if (min_const->GetValueAsDouble() >
784 constants[1]->GetValueAsDouble()) {
785 found_result = true;
786 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
787 cmp_opcode == SpvOpFUnordLessThanEqual);
788 }
789 }
790 }
791 break;
792 default:
793 return nullptr;
794 }
795
796 if (!found_result) {
797 return nullptr;
798 }
799
800 const analysis::Type* bool_type =
801 context->get_type_mgr()->GetType(inst->type_id());
802 const analysis::Constant* result_const =
803 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
804 assert(result_const);
805 return result_const;
806 };
807}
808
809} // namespace
810
811ConstantFoldingRules::ConstantFoldingRules() {
812 // Add all folding rules to the list for the opcodes to which they apply.
813 // Note that the order in which rules are added to the list matters. If a rule
814 // applies to the instruction, the rest of the rules will not be attempted.
815 // Take that into consideration.
816
817 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
818
819 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
820
821 rules_[SpvOpConvertFToS].push_back(FoldFToI());
822 rules_[SpvOpConvertFToU].push_back(FoldFToI());
823 rules_[SpvOpConvertSToF].push_back(FoldIToF());
824 rules_[SpvOpConvertUToF].push_back(FoldIToF());
825
826 rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
827 rules_[SpvOpFAdd].push_back(FoldFAdd());
828 rules_[SpvOpFDiv].push_back(FoldFDiv());
829 rules_[SpvOpFMul].push_back(FoldFMul());
830 rules_[SpvOpFSub].push_back(FoldFSub());
831
832 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
833
834 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
835
836 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
837
838 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
839
840 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
841 rules_[SpvOpFOrdLessThan].push_back(
842 FoldFClampFeedingCompare(SpvOpFOrdLessThan));
843
844 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
845 rules_[SpvOpFUnordLessThan].push_back(
846 FoldFClampFeedingCompare(SpvOpFUnordLessThan));
847
848 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
849 rules_[SpvOpFOrdGreaterThan].push_back(
850 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
851
852 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
853 rules_[SpvOpFUnordGreaterThan].push_back(
854 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
855
856 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
857 rules_[SpvOpFOrdLessThanEqual].push_back(
858 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
859
860 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
861 rules_[SpvOpFUnordLessThanEqual].push_back(
862 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
863
864 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
865 rules_[SpvOpFOrdGreaterThanEqual].push_back(
866 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
867
868 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
869 rules_[SpvOpFUnordGreaterThanEqual].push_back(
870 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
871
872 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
873 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
874
875 rules_[SpvOpFNegate].push_back(FoldFNegate());
Ben Claytonb73b7602019-07-29 13:56:13 +0100876 rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
Chris Forbescc5697f2019-01-30 11:54:08 -0800877}
878} // namespace opt
879} // namespace spvtools