blob: a9830a25bbbfcfc800065fd92d4ed2ec4da13976 [file] [log] [blame]
Chris Forbescc5697f2019-01-30 11:54:08 -08001// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/const_folding_rules.h"
16
17#include "source/opt/ir_context.h"
18
19namespace spvtools {
20namespace opt {
21namespace {
22
// In-operand index of OpCompositeExtract that holds the composite being
// extracted from; the literal member indices follow it (in-operands 1..N).
constexpr uint32_t kExtractCompositeIdInIdx = 0;
24
Nicolas Capens6cacf182021-11-30 11:15:46 -050025// Returns a constants with the value NaN of the given type. Only works for
26// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
27const analysis::Constant* GetNan(const analysis::Type* type,
28 analysis::ConstantManager* const_mgr) {
29 const analysis::Float* float_type = type->AsFloat();
30 if (float_type == nullptr) {
31 return nullptr;
32 }
33
34 switch (float_type->width()) {
35 case 32:
36 return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
37 case 64:
38 return const_mgr->GetDoubleConst(
39 std::numeric_limits<double>::quiet_NaN());
40 default:
41 return nullptr;
42 }
43}
44
45// Returns a constants with the value INF of the given type. Only works for
46// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
47const analysis::Constant* GetInf(const analysis::Type* type,
48 analysis::ConstantManager* const_mgr) {
49 const analysis::Float* float_type = type->AsFloat();
50 if (float_type == nullptr) {
51 return nullptr;
52 }
53
54 switch (float_type->width()) {
55 case 32:
56 return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
57 case 64:
58 return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
59 default:
60 return nullptr;
61 }
62}
63
Chris Forbescc5697f2019-01-30 11:54:08 -080064// Returns true if |type| is Float or a vector of Float.
65bool HasFloatingPoint(const analysis::Type* type) {
66 if (type->AsFloat()) {
67 return true;
68 } else if (const analysis::Vector* vec_type = type->AsVector()) {
69 return vec_type->element_type()->AsFloat() != nullptr;
70 }
71
72 return false;
73}
74
Nicolas Capens6cacf182021-11-30 11:15:46 -050075// Returns a constants with the value |-val| of the given type. Only works for
76// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
77const analysis::Constant* negateFPConst(const analysis::Type* result_type,
78 const analysis::Constant* val,
79 analysis::ConstantManager* const_mgr) {
80 const analysis::Float* float_type = result_type->AsFloat();
81 assert(float_type != nullptr);
82 if (float_type->width() == 32) {
83 float fa = val->GetFloat();
84 return const_mgr->GetFloatConst(-fa);
85 } else if (float_type->width() == 64) {
86 double da = val->GetDouble();
87 return const_mgr->GetDoubleConst(-da);
88 }
89 return nullptr;
90}
91
Chris Forbescc5697f2019-01-30 11:54:08 -080092// Folds an OpcompositeExtract where input is a composite constant.
93ConstantFoldingRule FoldExtractWithConstants() {
94 return [](IRContext* context, Instruction* inst,
95 const std::vector<const analysis::Constant*>& constants)
96 -> const analysis::Constant* {
97 const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
98 if (c == nullptr) {
99 return nullptr;
100 }
101
102 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
103 uint32_t element_index = inst->GetSingleWordInOperand(i);
104 if (c->AsNullConstant()) {
105 // Return Null for the return type.
106 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
107 analysis::TypeManager* type_mgr = context->get_type_mgr();
108 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
109 }
110
111 auto cc = c->AsCompositeConstant();
112 assert(cc != nullptr);
113 auto components = cc->GetComponents();
Ben Claytond0f684e2019-08-30 22:36:08 +0100114 // Protect against invalid IR. Refuse to fold if the index is out
115 // of bounds.
116 if (element_index >= components.size()) return nullptr;
Chris Forbescc5697f2019-01-30 11:54:08 -0800117 c = components[element_index];
118 }
119 return c;
120 };
121}
122
123ConstantFoldingRule FoldVectorShuffleWithConstants() {
124 return [](IRContext* context, Instruction* inst,
125 const std::vector<const analysis::Constant*>& constants)
126 -> const analysis::Constant* {
127 assert(inst->opcode() == SpvOpVectorShuffle);
128 const analysis::Constant* c1 = constants[0];
129 const analysis::Constant* c2 = constants[1];
130 if (c1 == nullptr || c2 == nullptr) {
131 return nullptr;
132 }
133
134 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
135 const analysis::Type* element_type = c1->type()->AsVector()->element_type();
136
137 std::vector<const analysis::Constant*> c1_components;
138 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
139 c1_components = vec_const->GetComponents();
140 } else {
141 assert(c1->AsNullConstant());
142 const analysis::Constant* element =
143 const_mgr->GetConstant(element_type, {});
144 c1_components.resize(c1->type()->AsVector()->element_count(), element);
145 }
146 std::vector<const analysis::Constant*> c2_components;
147 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
148 c2_components = vec_const->GetComponents();
149 } else {
150 assert(c2->AsNullConstant());
151 const analysis::Constant* element =
152 const_mgr->GetConstant(element_type, {});
153 c2_components.resize(c2->type()->AsVector()->element_count(), element);
154 }
155
156 std::vector<uint32_t> ids;
157 const uint32_t undef_literal_value = 0xffffffff;
158 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
159 uint32_t index = inst->GetSingleWordInOperand(i);
160 if (index == undef_literal_value) {
161 // Don't fold shuffle with undef literal value.
162 return nullptr;
163 } else if (index < c1_components.size()) {
164 Instruction* member_inst =
165 const_mgr->GetDefiningInstruction(c1_components[index]);
166 ids.push_back(member_inst->result_id());
167 } else {
168 Instruction* member_inst = const_mgr->GetDefiningInstruction(
169 c2_components[index - c1_components.size()]);
170 ids.push_back(member_inst->result_id());
171 }
172 }
173
174 analysis::TypeManager* type_mgr = context->get_type_mgr();
175 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
176 };
177}
178
179ConstantFoldingRule FoldVectorTimesScalar() {
180 return [](IRContext* context, Instruction* inst,
181 const std::vector<const analysis::Constant*>& constants)
182 -> const analysis::Constant* {
183 assert(inst->opcode() == SpvOpVectorTimesScalar);
184 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
185 analysis::TypeManager* type_mgr = context->get_type_mgr();
186
187 if (!inst->IsFloatingPointFoldingAllowed()) {
188 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
189 return nullptr;
190 }
191 }
192
193 const analysis::Constant* c1 = constants[0];
194 const analysis::Constant* c2 = constants[1];
195
196 if (c1 && c1->IsZero()) {
197 return c1;
198 }
199
200 if (c2 && c2->IsZero()) {
201 // Get or create the NullConstant for this type.
202 std::vector<uint32_t> ids;
203 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
204 }
205
206 if (c1 == nullptr || c2 == nullptr) {
207 return nullptr;
208 }
209
210 // Check result type.
211 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
212 const analysis::Vector* vector_type = result_type->AsVector();
213 assert(vector_type != nullptr);
214 const analysis::Type* element_type = vector_type->element_type();
215 assert(element_type != nullptr);
216 const analysis::Float* float_type = element_type->AsFloat();
217 assert(float_type != nullptr);
218
219 // Check types of c1 and c2.
220 assert(c1->type()->AsVector() == vector_type);
221 assert(c1->type()->AsVector()->element_type() == element_type &&
222 c2->type() == element_type);
223
224 // Get a float vector that is the result of vector-times-scalar.
225 std::vector<const analysis::Constant*> c1_components =
226 c1->GetVectorComponents(const_mgr);
227 std::vector<uint32_t> ids;
228 if (float_type->width() == 32) {
229 float scalar = c2->GetFloat();
230 for (uint32_t i = 0; i < c1_components.size(); ++i) {
231 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
232 std::vector<uint32_t> words = result.GetWords();
233 const analysis::Constant* new_elem =
234 const_mgr->GetConstant(float_type, words);
235 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
236 }
237 return const_mgr->GetConstant(vector_type, ids);
238 } else if (float_type->width() == 64) {
239 double scalar = c2->GetDouble();
240 for (uint32_t i = 0; i < c1_components.size(); ++i) {
241 utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
242 scalar);
243 std::vector<uint32_t> words = result.GetWords();
244 const analysis::Constant* new_elem =
245 const_mgr->GetConstant(float_type, words);
246 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
247 }
248 return const_mgr->GetConstant(vector_type, ids);
249 }
250 return nullptr;
251 };
252}
253
254ConstantFoldingRule FoldCompositeWithConstants() {
255 // Folds an OpCompositeConstruct where all of the inputs are constants to a
256 // constant. A new constant is created if necessary.
257 return [](IRContext* context, Instruction* inst,
258 const std::vector<const analysis::Constant*>& constants)
259 -> const analysis::Constant* {
260 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
261 analysis::TypeManager* type_mgr = context->get_type_mgr();
262 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
263 Instruction* type_inst =
264 context->get_def_use_mgr()->GetDef(inst->type_id());
265
266 std::vector<uint32_t> ids;
267 for (uint32_t i = 0; i < constants.size(); ++i) {
268 const analysis::Constant* element_const = constants[i];
269 if (element_const == nullptr) {
270 return nullptr;
271 }
272
273 uint32_t component_type_id = 0;
274 if (type_inst->opcode() == SpvOpTypeStruct) {
275 component_type_id = type_inst->GetSingleWordInOperand(i);
276 } else if (type_inst->opcode() == SpvOpTypeArray) {
277 component_type_id = type_inst->GetSingleWordInOperand(0);
278 }
279
280 uint32_t element_id =
281 const_mgr->FindDeclaredConstant(element_const, component_type_id);
282 if (element_id == 0) {
283 return nullptr;
284 }
285 ids.push_back(element_id);
286 }
287 return const_mgr->GetConstant(new_type, ids);
288 };
289}
290
291// The interface for a function that returns the result of applying a scalar
292// floating-point binary operation on |a| and |b|. The type of the return value
293// will be |type|. The input constants must also be of type |type|.
294using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
295 const analysis::Type* result_type, const analysis::Constant* a,
296 analysis::ConstantManager*)>;
297
298// The interface for a function that returns the result of applying a scalar
299// floating-point binary operation on |a| and |b|. The type of the return value
300// will be |type|. The input constants must also be of type |type|.
301using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
302 const analysis::Type* result_type, const analysis::Constant* a,
303 const analysis::Constant* b, analysis::ConstantManager*)>;
304
305// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
306// using |scalar_rule| and unary float point vectors ops by applying
307// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
308// that is returned assumes that |constants| contains 1 entry. If they are
309// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
310// whose element type is |Float| or |Integer|.
311ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
312 return [scalar_rule](IRContext* context, Instruction* inst,
313 const std::vector<const analysis::Constant*>& constants)
314 -> const analysis::Constant* {
315 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
316 analysis::TypeManager* type_mgr = context->get_type_mgr();
317 const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
318 const analysis::Vector* vector_type = result_type->AsVector();
319
320 if (!inst->IsFloatingPointFoldingAllowed()) {
321 return nullptr;
322 }
323
Ben Claytondc6b76a2020-02-24 14:53:40 +0000324 const analysis::Constant* arg =
325 (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0];
326
327 if (arg == nullptr) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800328 return nullptr;
329 }
330
331 if (vector_type != nullptr) {
332 std::vector<const analysis::Constant*> a_components;
333 std::vector<const analysis::Constant*> results_components;
334
Ben Claytondc6b76a2020-02-24 14:53:40 +0000335 a_components = arg->GetVectorComponents(const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800336
337 // Fold each component of the vector.
338 for (uint32_t i = 0; i < a_components.size(); ++i) {
339 results_components.push_back(scalar_rule(vector_type->element_type(),
340 a_components[i], const_mgr));
341 if (results_components[i] == nullptr) {
342 return nullptr;
343 }
344 }
345
346 // Build the constant object and return it.
347 std::vector<uint32_t> ids;
348 for (const analysis::Constant* member : results_components) {
349 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
350 }
351 return const_mgr->GetConstant(vector_type, ids);
352 } else {
Ben Claytondc6b76a2020-02-24 14:53:40 +0000353 return scalar_rule(result_type, arg, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800354 }
355 };
356}
357
Ben Claytond552f632019-11-18 11:18:41 +0000358// Returns the result of folding the constants in |constants| according the
359// |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied
360// per component.
361const analysis::Constant* FoldFPBinaryOp(
362 BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
363 const std::vector<const analysis::Constant*>& constants,
364 IRContext* context) {
365 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
366 analysis::TypeManager* type_mgr = context->get_type_mgr();
367 const analysis::Type* result_type = type_mgr->GetType(result_type_id);
368 const analysis::Vector* vector_type = result_type->AsVector();
369
370 if (constants[0] == nullptr || constants[1] == nullptr) {
371 return nullptr;
372 }
373
374 if (vector_type != nullptr) {
375 std::vector<const analysis::Constant*> a_components;
376 std::vector<const analysis::Constant*> b_components;
377 std::vector<const analysis::Constant*> results_components;
378
379 a_components = constants[0]->GetVectorComponents(const_mgr);
380 b_components = constants[1]->GetVectorComponents(const_mgr);
381
382 // Fold each component of the vector.
383 for (uint32_t i = 0; i < a_components.size(); ++i) {
384 results_components.push_back(scalar_rule(vector_type->element_type(),
385 a_components[i], b_components[i],
386 const_mgr));
387 if (results_components[i] == nullptr) {
388 return nullptr;
389 }
390 }
391
392 // Build the constant object and return it.
393 std::vector<uint32_t> ids;
394 for (const analysis::Constant* member : results_components) {
395 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
396 }
397 return const_mgr->GetConstant(vector_type, ids);
398 } else {
399 return scalar_rule(result_type, constants[0], constants[1], const_mgr);
400 }
401}
402
Chris Forbescc5697f2019-01-30 11:54:08 -0800403// Returns a |ConstantFoldingRule| that folds floating point scalars using
404// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
405// elements of the vector. The |ConstantFoldingRule| that is returned assumes
406// that |constants| contains 2 entries. If they are not |nullptr|, then their
407// type is either |Float| or a |Vector| whose element type is |Float|.
408ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
409 return [scalar_rule](IRContext* context, Instruction* inst,
410 const std::vector<const analysis::Constant*>& constants)
411 -> const analysis::Constant* {
Chris Forbescc5697f2019-01-30 11:54:08 -0800412 if (!inst->IsFloatingPointFoldingAllowed()) {
413 return nullptr;
414 }
Ben Claytond552f632019-11-18 11:18:41 +0000415 if (inst->opcode() == SpvOpExtInst) {
416 return FoldFPBinaryOp(scalar_rule, inst->type_id(),
417 {constants[1], constants[2]}, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800418 }
Ben Claytond552f632019-11-18 11:18:41 +0000419 return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
Chris Forbescc5697f2019-01-30 11:54:08 -0800420 };
421}
422
423// This macro defines a |UnaryScalarFoldingRule| that performs float to
424// integer conversion.
425// TODO(greg-lunarg): Support for 64-bit integer types.
426UnaryScalarFoldingRule FoldFToIOp() {
427 return [](const analysis::Type* result_type, const analysis::Constant* a,
428 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
429 assert(result_type != nullptr && a != nullptr);
430 const analysis::Integer* integer_type = result_type->AsInteger();
431 const analysis::Float* float_type = a->type()->AsFloat();
432 assert(float_type != nullptr);
433 assert(integer_type != nullptr);
434 if (integer_type->width() != 32) return nullptr;
435 if (float_type->width() == 32) {
436 float fa = a->GetFloat();
437 uint32_t result = integer_type->IsSigned()
438 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
439 : static_cast<uint32_t>(fa);
440 std::vector<uint32_t> words = {result};
441 return const_mgr->GetConstant(result_type, words);
442 } else if (float_type->width() == 64) {
443 double fa = a->GetDouble();
444 uint32_t result = integer_type->IsSigned()
445 ? static_cast<uint32_t>(static_cast<int32_t>(fa))
446 : static_cast<uint32_t>(fa);
447 std::vector<uint32_t> words = {result};
448 return const_mgr->GetConstant(result_type, words);
449 }
450 return nullptr;
451 };
452}
453
454// This function defines a |UnaryScalarFoldingRule| that performs integer to
455// float conversion.
456// TODO(greg-lunarg): Support for 64-bit integer types.
457UnaryScalarFoldingRule FoldIToFOp() {
458 return [](const analysis::Type* result_type, const analysis::Constant* a,
459 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
460 assert(result_type != nullptr && a != nullptr);
461 const analysis::Integer* integer_type = a->type()->AsInteger();
462 const analysis::Float* float_type = result_type->AsFloat();
463 assert(float_type != nullptr);
464 assert(integer_type != nullptr);
465 if (integer_type->width() != 32) return nullptr;
466 uint32_t ua = a->GetU32();
467 if (float_type->width() == 32) {
468 float result_val = integer_type->IsSigned()
469 ? static_cast<float>(static_cast<int32_t>(ua))
470 : static_cast<float>(ua);
471 utils::FloatProxy<float> result(result_val);
472 std::vector<uint32_t> words = {result.data()};
473 return const_mgr->GetConstant(result_type, words);
474 } else if (float_type->width() == 64) {
475 double result_val = integer_type->IsSigned()
476 ? static_cast<double>(static_cast<int32_t>(ua))
477 : static_cast<double>(ua);
478 utils::FloatProxy<double> result(result_val);
479 std::vector<uint32_t> words = result.GetWords();
480 return const_mgr->GetConstant(result_type, words);
481 }
482 return nullptr;
483 };
484}
485
Ben Claytonb73b7602019-07-29 13:56:13 +0100486// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
487UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
488 return [](const analysis::Type* result_type, const analysis::Constant* a,
489 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
490 assert(result_type != nullptr && a != nullptr);
491 const analysis::Float* float_type = a->type()->AsFloat();
492 assert(float_type != nullptr);
493 if (float_type->width() != 32) {
494 return nullptr;
495 }
496
497 float fa = a->GetFloat();
498 utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
499 utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
500 utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
501 orignal.castTo(quantized, utils::round_direction::kToZero);
502 quantized.castTo(result, utils::round_direction::kToZero);
503 std::vector<uint32_t> words = {result.getBits()};
504 return const_mgr->GetConstant(result_type, words);
505 };
506}
507
// Defines a |BinaryScalarFoldingRule| lambda that folds "a op b" for two
// constants of the same 32-bit or 64-bit float type, returning |nullptr| for
// other widths. |op| must be an infix operator valid for both float and double.
// Parameter and local names carry an "_in_macro" suffix, presumably to avoid
// shadowing names at expansion sites that sit inside other lambdas.
// (No comments inside the body: a // comment would swallow the spliced lines.)
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                              \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    std::vector<uint32_t> words_in_macro;                                     \
    switch (float_type_in_macro->width()) {                                   \
      case 32:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<float>(a->GetFloat() op b->GetFloat())          \
                .GetWords();                                                  \
        break;                                                                \
      case 64:                                                                \
        words_in_macro =                                                      \
            utils::FloatProxy<double>(a->GetDouble() op b->GetDouble())       \
                .GetWords();                                                  \
        break;                                                                \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
    return const_mgr_in_macro->GetConstant(result_type_in_macro,              \
                                           words_in_macro);                   \
  }
538
539// Define the folding rule for conversion between floating point and integer
540ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
541ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
Ben Claytonb73b7602019-07-29 13:56:13 +0100542ConstantFoldingRule FoldQuantizeToF16() {
543 return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
544}
Chris Forbescc5697f2019-01-30 11:54:08 -0800545
546// Define the folding rules for subtraction, addition, multiplication, and
547// division for floating point values.
548ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
549ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
550ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
Nicolas Capens6cacf182021-11-30 11:15:46 -0500551
552// Returns the constant that results from evaluating |numerator| / 0.0. Returns
553// |nullptr| if the result could not be evalutated.
554const analysis::Constant* FoldFPScalarDivideByZero(
555 const analysis::Type* result_type, const analysis::Constant* numerator,
556 analysis::ConstantManager* const_mgr) {
557 if (numerator == nullptr) {
558 return nullptr;
559 }
560
561 if (numerator->IsZero()) {
562 return GetNan(result_type, const_mgr);
563 }
564
565 const analysis::Constant* result = GetInf(result_type, const_mgr);
566 if (result == nullptr) {
567 return nullptr;
568 }
569
570 if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
571 result = negateFPConst(result_type, result, const_mgr);
572 }
573 return result;
574}
575
576// Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
577// if it cannot be folded.
578const analysis::Constant* FoldScalarFPDivide(
579 const analysis::Type* result_type, const analysis::Constant* numerator,
580 const analysis::Constant* denominator,
581 analysis::ConstantManager* const_mgr) {
582 if (denominator == nullptr) {
583 return nullptr;
584 }
585
586 if (denominator->IsZero()) {
587 return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
588 }
589
590 const analysis::FloatConstant* denominator_float =
591 denominator->AsFloatConstant();
592 if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
593 const analysis::Constant* result =
594 FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
595 if (result != nullptr)
596 result = negateFPConst(result_type, result, const_mgr);
597 return result;
598 } else {
599 return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
600 }
601}
602
603// Returns the constant folding rule to fold |OpFDiv| with two constants.
604ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
Chris Forbescc5697f2019-01-30 11:54:08 -0800605
// Combines the raw result of a floating-point comparison with NaN handling.
// |op_result| is the C++ comparison result. |op_unordered| is true when either
// operand is NaN. |need_ordered| selects the comparison flavor:
//   * ordered comparisons are false whenever the operands are unordered;
//   * unordered comparisons are true whenever the operands are unordered.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
616
// Defines a |BinaryScalarFoldingRule| lambda that folds the comparison
// "a op b" into a bool constant. Both operands must be constants of the same
// 32-bit or 64-bit float type; |nullptr| is returned for other widths.
// |ord| selects the NaN semantics, as implemented by CompareFloatingPoint:
//   * true (ordered): the result is false if either operand is NaN;
//   * false (unordered): the result is true if either operand is NaN.
#define FOLD_FPCMP_OP(op, ord)                                                \
  [](const analysis::Type* result_type, const analysis::Constant* a,         \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* {     \
    assert(result_type != nullptr && a != nullptr && b != nullptr);           \
    assert(result_type->AsBool());                                            \
    assert(a->type() == b->type());                                           \
    const analysis::Float* float_type = a->type()->AsFloat();                 \
    assert(float_type != nullptr);                                            \
    bool result = false;                                                      \
    if (float_type->width() == 32) {                                          \
      const float lhs = a->GetFloat();                                        \
      const float rhs = b->GetFloat();                                        \
      result = CompareFloatingPoint(                                          \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);               \
    } else if (float_type->width() == 64) {                                   \
      const double lhs = a->GetDouble();                                      \
      const double rhs = b->GetDouble();                                      \
      result = CompareFloatingPoint(                                          \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);               \
    } else {                                                                  \
      return nullptr;                                                         \
    }                                                                         \
    std::vector<uint32_t> words = {static_cast<uint32_t>(result)};            \
    return const_mgr->GetConstant(result_type, words);                        \
  }
645
646// Define the folding rules for ordered and unordered comparison for floating
647// point values.
648ConstantFoldingRule FoldFOrdEqual() {
649 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
650}
651ConstantFoldingRule FoldFUnordEqual() {
652 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
653}
654ConstantFoldingRule FoldFOrdNotEqual() {
655 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
656}
657ConstantFoldingRule FoldFUnordNotEqual() {
658 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
659}
660ConstantFoldingRule FoldFOrdLessThan() {
661 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
662}
663ConstantFoldingRule FoldFUnordLessThan() {
664 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
665}
666ConstantFoldingRule FoldFOrdGreaterThan() {
667 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
668}
669ConstantFoldingRule FoldFUnordGreaterThan() {
670 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
671}
672ConstantFoldingRule FoldFOrdLessThanEqual() {
673 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
674}
675ConstantFoldingRule FoldFUnordLessThanEqual() {
676 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
677}
678ConstantFoldingRule FoldFOrdGreaterThanEqual() {
679 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
680}
681ConstantFoldingRule FoldFUnordGreaterThanEqual() {
682 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
683}
684
685// Folds an OpDot where all of the inputs are constants to a
686// constant. A new constant is created if necessary.
687ConstantFoldingRule FoldOpDotWithConstants() {
688 return [](IRContext* context, Instruction* inst,
689 const std::vector<const analysis::Constant*>& constants)
690 -> const analysis::Constant* {
691 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
692 analysis::TypeManager* type_mgr = context->get_type_mgr();
693 const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
694 assert(new_type->AsFloat() && "OpDot should have a float return type.");
695 const analysis::Float* float_type = new_type->AsFloat();
696
697 if (!inst->IsFloatingPointFoldingAllowed()) {
698 return nullptr;
699 }
700
701 // If one of the operands is 0, then the result is 0.
702 bool has_zero_operand = false;
703
704 for (int i = 0; i < 2; ++i) {
705 if (constants[i]) {
706 if (constants[i]->AsNullConstant() ||
707 constants[i]->AsVectorConstant()->IsZero()) {
708 has_zero_operand = true;
709 break;
710 }
711 }
712 }
713
714 if (has_zero_operand) {
715 if (float_type->width() == 32) {
716 utils::FloatProxy<float> result(0.0f);
717 std::vector<uint32_t> words = result.GetWords();
718 return const_mgr->GetConstant(float_type, words);
719 }
720 if (float_type->width() == 64) {
721 utils::FloatProxy<double> result(0.0);
722 std::vector<uint32_t> words = result.GetWords();
723 return const_mgr->GetConstant(float_type, words);
724 }
725 return nullptr;
726 }
727
728 if (constants[0] == nullptr || constants[1] == nullptr) {
729 return nullptr;
730 }
731
732 std::vector<const analysis::Constant*> a_components;
733 std::vector<const analysis::Constant*> b_components;
734
735 a_components = constants[0]->GetVectorComponents(const_mgr);
736 b_components = constants[1]->GetVectorComponents(const_mgr);
737
738 utils::FloatProxy<double> result(0.0);
739 std::vector<uint32_t> words = result.GetWords();
740 const analysis::Constant* result_const =
741 const_mgr->GetConstant(float_type, words);
Ben Claytonb73b7602019-07-29 13:56:13 +0100742 for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
743 ++i) {
Chris Forbescc5697f2019-01-30 11:54:08 -0800744 if (a_components[i] == nullptr || b_components[i] == nullptr) {
745 return nullptr;
746 }
747
748 const analysis::Constant* component = FOLD_FPARITH_OP(*)(
749 new_type, a_components[i], b_components[i], const_mgr);
Ben Claytonb73b7602019-07-29 13:56:13 +0100750 if (component == nullptr) {
751 return nullptr;
752 }
Chris Forbescc5697f2019-01-30 11:54:08 -0800753 result_const =
754 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
755 }
756 return result_const;
757 };
758}
759
760// This function defines a |UnaryScalarFoldingRule| that subtracts the constant
761// from zero.
762UnaryScalarFoldingRule FoldFNegateOp() {
763 return [](const analysis::Type* result_type, const analysis::Constant* a,
764 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
765 assert(result_type != nullptr && a != nullptr);
766 assert(result_type == a->type());
Nicolas Capens6cacf182021-11-30 11:15:46 -0500767 return negateFPConst(result_type, a, const_mgr);
Chris Forbescc5697f2019-01-30 11:54:08 -0800768 };
769}
770
771ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
772
773ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
774 return [cmp_opcode](IRContext* context, Instruction* inst,
775 const std::vector<const analysis::Constant*>& constants)
776 -> const analysis::Constant* {
777 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
778 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
779
780 if (!inst->IsFloatingPointFoldingAllowed()) {
781 return nullptr;
782 }
783
784 uint32_t non_const_idx = (constants[0] ? 1 : 0);
785 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
786 Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
787
788 analysis::TypeManager* type_mgr = context->get_type_mgr();
789 const analysis::Type* operand_type =
790 type_mgr->GetType(operand_inst->type_id());
791
792 if (!operand_type->AsFloat()) {
793 return nullptr;
794 }
795
796 if (operand_type->AsFloat()->width() != 32 &&
797 operand_type->AsFloat()->width() != 64) {
798 return nullptr;
799 }
800
801 if (operand_inst->opcode() != SpvOpExtInst) {
802 return nullptr;
803 }
804
805 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
806 return nullptr;
807 }
808
809 if (constants[1] == nullptr && constants[0] == nullptr) {
810 return nullptr;
811 }
812
813 uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
814 const analysis::Constant* max_const =
815 const_mgr->FindDeclaredConstant(max_id);
816
817 uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
818 const analysis::Constant* min_const =
819 const_mgr->FindDeclaredConstant(min_id);
820
821 bool found_result = false;
822 bool result = false;
823
824 switch (cmp_opcode) {
825 case SpvOpFOrdLessThan:
826 case SpvOpFUnordLessThan:
827 case SpvOpFOrdGreaterThanEqual:
828 case SpvOpFUnordGreaterThanEqual:
829 if (constants[0]) {
830 if (min_const) {
831 if (constants[0]->GetValueAsDouble() <
832 min_const->GetValueAsDouble()) {
833 found_result = true;
834 result = (cmp_opcode == SpvOpFOrdLessThan ||
835 cmp_opcode == SpvOpFUnordLessThan);
836 }
837 }
838 if (max_const) {
839 if (constants[0]->GetValueAsDouble() >=
840 max_const->GetValueAsDouble()) {
841 found_result = true;
842 result = !(cmp_opcode == SpvOpFOrdLessThan ||
843 cmp_opcode == SpvOpFUnordLessThan);
844 }
845 }
846 }
847
848 if (constants[1]) {
849 if (max_const) {
850 if (max_const->GetValueAsDouble() <
851 constants[1]->GetValueAsDouble()) {
852 found_result = true;
853 result = (cmp_opcode == SpvOpFOrdLessThan ||
854 cmp_opcode == SpvOpFUnordLessThan);
855 }
856 }
857
858 if (min_const) {
859 if (min_const->GetValueAsDouble() >=
860 constants[1]->GetValueAsDouble()) {
861 found_result = true;
862 result = !(cmp_opcode == SpvOpFOrdLessThan ||
863 cmp_opcode == SpvOpFUnordLessThan);
864 }
865 }
866 }
867 break;
868 case SpvOpFOrdGreaterThan:
869 case SpvOpFUnordGreaterThan:
870 case SpvOpFOrdLessThanEqual:
871 case SpvOpFUnordLessThanEqual:
872 if (constants[0]) {
873 if (min_const) {
874 if (constants[0]->GetValueAsDouble() <=
875 min_const->GetValueAsDouble()) {
876 found_result = true;
877 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
878 cmp_opcode == SpvOpFUnordLessThanEqual);
879 }
880 }
881 if (max_const) {
882 if (constants[0]->GetValueAsDouble() >
883 max_const->GetValueAsDouble()) {
884 found_result = true;
885 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
886 cmp_opcode == SpvOpFUnordLessThanEqual);
887 }
888 }
889 }
890
891 if (constants[1]) {
892 if (max_const) {
893 if (max_const->GetValueAsDouble() <=
894 constants[1]->GetValueAsDouble()) {
895 found_result = true;
896 result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
897 cmp_opcode == SpvOpFUnordLessThanEqual);
898 }
899 }
900
901 if (min_const) {
902 if (min_const->GetValueAsDouble() >
903 constants[1]->GetValueAsDouble()) {
904 found_result = true;
905 result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
906 cmp_opcode == SpvOpFUnordLessThanEqual);
907 }
908 }
909 }
910 break;
911 default:
912 return nullptr;
913 }
914
915 if (!found_result) {
916 return nullptr;
917 }
918
919 const analysis::Type* bool_type =
920 context->get_type_mgr()->GetType(inst->type_id());
921 const analysis::Constant* result_const =
922 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
923 assert(result_const);
924 return result_const;
925 };
926}
927
Ben Claytond0f684e2019-08-30 22:36:08 +0100928ConstantFoldingRule FoldFMix() {
929 return [](IRContext* context, Instruction* inst,
930 const std::vector<const analysis::Constant*>& constants)
931 -> const analysis::Constant* {
932 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
933 assert(inst->opcode() == SpvOpExtInst &&
934 "Expecting an extended instruction.");
935 assert(inst->GetSingleWordInOperand(0) ==
936 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
937 "Expecting a GLSLstd450 extended instruction.");
938 assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
939 "Expecting and FMix instruction.");
940
941 if (!inst->IsFloatingPointFoldingAllowed()) {
942 return nullptr;
943 }
944
945 // Make sure all FMix operands are constants.
946 for (uint32_t i = 1; i < 4; i++) {
947 if (constants[i] == nullptr) {
948 return nullptr;
949 }
950 }
951
952 const analysis::Constant* one;
Ben Claytond552f632019-11-18 11:18:41 +0000953 bool is_vector = false;
954 const analysis::Type* result_type = constants[1]->type();
955 const analysis::Type* base_type = result_type;
956 if (base_type->AsVector()) {
957 is_vector = true;
958 base_type = base_type->AsVector()->element_type();
959 }
960 assert(base_type->AsFloat() != nullptr &&
961 "FMix is suppose to act on floats or vectors of floats.");
962
963 if (base_type->AsFloat()->width() == 32) {
964 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +0100965 utils::FloatProxy<float>(1.0f).GetWords());
966 } else {
Ben Claytond552f632019-11-18 11:18:41 +0000967 one = const_mgr->GetConstant(base_type,
Ben Claytond0f684e2019-08-30 22:36:08 +0100968 utils::FloatProxy<double>(1.0).GetWords());
969 }
970
Ben Claytond552f632019-11-18 11:18:41 +0000971 if (is_vector) {
972 uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id();
973 one =
974 const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
975 }
976
977 const analysis::Constant* temp1 = FoldFPBinaryOp(
978 FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100979 if (temp1 == nullptr) {
980 return nullptr;
981 }
982
Ben Claytond552f632019-11-18 11:18:41 +0000983 const analysis::Constant* temp2 = FoldFPBinaryOp(
984 FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100985 if (temp2 == nullptr) {
986 return nullptr;
987 }
Ben Claytond552f632019-11-18 11:18:41 +0000988 const analysis::Constant* temp3 =
989 FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
990 {constants[2], constants[3]}, context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100991 if (temp3 == nullptr) {
992 return nullptr;
993 }
Ben Claytond552f632019-11-18 11:18:41 +0000994 return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
995 context);
Ben Claytond0f684e2019-08-30 22:36:08 +0100996 };
997}
998
// Clamps |x| into [|min_val|, |max_val|]. The lower bound is applied before
// the upper bound, so when |min_val| > |max_val| the result is |max_val|.
template <class IntType>
IntType FoldIClamp(IntType x, IntType min_val, IntType max_val) {
  const IntType at_least_min = (x < min_val) ? min_val : x;
  return (at_least_min > max_val) ? max_val : at_least_min;
}
1009
1010const analysis::Constant* FoldMin(const analysis::Type* result_type,
1011 const analysis::Constant* a,
1012 const analysis::Constant* b,
1013 analysis::ConstantManager*) {
1014 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1015 if (int_type->width() == 32) {
1016 if (int_type->IsSigned()) {
1017 int32_t va = a->GetS32();
1018 int32_t vb = b->GetS32();
1019 return (va < vb ? a : b);
1020 } else {
1021 uint32_t va = a->GetU32();
1022 uint32_t vb = b->GetU32();
1023 return (va < vb ? a : b);
1024 }
1025 } else if (int_type->width() == 64) {
1026 if (int_type->IsSigned()) {
1027 int64_t va = a->GetS64();
1028 int64_t vb = b->GetS64();
1029 return (va < vb ? a : b);
1030 } else {
1031 uint64_t va = a->GetU64();
1032 uint64_t vb = b->GetU64();
1033 return (va < vb ? a : b);
1034 }
1035 }
1036 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1037 if (float_type->width() == 32) {
1038 float va = a->GetFloat();
1039 float vb = b->GetFloat();
1040 return (va < vb ? a : b);
1041 } else if (float_type->width() == 64) {
1042 double va = a->GetDouble();
1043 double vb = b->GetDouble();
1044 return (va < vb ? a : b);
1045 }
1046 }
1047 return nullptr;
1048}
1049
1050const analysis::Constant* FoldMax(const analysis::Type* result_type,
1051 const analysis::Constant* a,
1052 const analysis::Constant* b,
1053 analysis::ConstantManager*) {
1054 if (const analysis::Integer* int_type = result_type->AsInteger()) {
1055 if (int_type->width() == 32) {
1056 if (int_type->IsSigned()) {
1057 int32_t va = a->GetS32();
1058 int32_t vb = b->GetS32();
1059 return (va > vb ? a : b);
1060 } else {
1061 uint32_t va = a->GetU32();
1062 uint32_t vb = b->GetU32();
1063 return (va > vb ? a : b);
1064 }
1065 } else if (int_type->width() == 64) {
1066 if (int_type->IsSigned()) {
1067 int64_t va = a->GetS64();
1068 int64_t vb = b->GetS64();
1069 return (va > vb ? a : b);
1070 } else {
1071 uint64_t va = a->GetU64();
1072 uint64_t vb = b->GetU64();
1073 return (va > vb ? a : b);
1074 }
1075 }
1076 } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1077 if (float_type->width() == 32) {
1078 float va = a->GetFloat();
1079 float vb = b->GetFloat();
1080 return (va > vb ? a : b);
1081 } else if (float_type->width() == 64) {
1082 double va = a->GetDouble();
1083 double vb = b->GetDouble();
1084 return (va > vb ? a : b);
1085 }
1086 }
1087 return nullptr;
1088}
1089
1090// Fold an clamp instruction when all three operands are constant.
1091const analysis::Constant* FoldClamp1(
1092 IRContext* context, Instruction* inst,
1093 const std::vector<const analysis::Constant*>& constants) {
1094 assert(inst->opcode() == SpvOpExtInst &&
1095 "Expecting an extended instruction.");
1096 assert(inst->GetSingleWordInOperand(0) ==
1097 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1098 "Expecting a GLSLstd450 extended instruction.");
1099
1100 // Make sure all Clamp operands are constants.
Alexis Hetu00e0af12021-11-08 08:57:46 -05001101 for (uint32_t i = 1; i < 4; i++) {
Ben Claytond552f632019-11-18 11:18:41 +00001102 if (constants[i] == nullptr) {
1103 return nullptr;
1104 }
1105 }
1106
1107 const analysis::Constant* temp = FoldFPBinaryOp(
1108 FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1109 if (temp == nullptr) {
1110 return nullptr;
1111 }
1112 return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1113 context);
1114}
1115
Alexis Hetu00e0af12021-11-08 08:57:46 -05001116// Fold a clamp instruction when |x <= min_val|.
Ben Claytond552f632019-11-18 11:18:41 +00001117const analysis::Constant* FoldClamp2(
1118 IRContext* context, Instruction* inst,
1119 const std::vector<const analysis::Constant*>& constants) {
1120 assert(inst->opcode() == SpvOpExtInst &&
1121 "Expecting an extended instruction.");
1122 assert(inst->GetSingleWordInOperand(0) ==
1123 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1124 "Expecting a GLSLstd450 extended instruction.");
1125
1126 const analysis::Constant* x = constants[1];
1127 const analysis::Constant* min_val = constants[2];
1128
1129 if (x == nullptr || min_val == nullptr) {
1130 return nullptr;
1131 }
1132
1133 const analysis::Constant* temp =
1134 FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1135 if (temp == min_val) {
1136 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1137 // result of the max operation is |min_val|, we know the result of the min
1138 // operation, even if |max_val| is not a constant.
1139 return min_val;
1140 }
1141 return nullptr;
1142}
1143
1144// Fold a clamp instruction when |x >= max_val|.
1145const analysis::Constant* FoldClamp3(
1146 IRContext* context, Instruction* inst,
1147 const std::vector<const analysis::Constant*>& constants) {
1148 assert(inst->opcode() == SpvOpExtInst &&
1149 "Expecting an extended instruction.");
1150 assert(inst->GetSingleWordInOperand(0) ==
1151 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1152 "Expecting a GLSLstd450 extended instruction.");
1153
1154 const analysis::Constant* x = constants[1];
1155 const analysis::Constant* max_val = constants[3];
1156
1157 if (x == nullptr || max_val == nullptr) {
1158 return nullptr;
1159 }
1160
1161 const analysis::Constant* temp =
1162 FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1163 if (temp == max_val) {
1164 // We can assume that |min_val| is less than |max_val|. Therefore, if the
1165 // result of the max operation is |min_val|, we know the result of the min
1166 // operation, even if |max_val| is not a constant.
1167 return max_val;
1168 }
1169 return nullptr;
1170}
1171
Ben Claytondc6b76a2020-02-24 14:53:40 +00001172UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1173 return
1174 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1175 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1176 assert(result_type != nullptr && a != nullptr);
1177 const analysis::Float* float_type = a->type()->AsFloat();
1178 assert(float_type != nullptr);
1179 assert(float_type == result_type->AsFloat());
1180 if (float_type->width() == 32) {
1181 float fa = a->GetFloat();
1182 float res = static_cast<float>(fp(fa));
1183 utils::FloatProxy<float> result(res);
1184 std::vector<uint32_t> words = result.GetWords();
1185 return const_mgr->GetConstant(result_type, words);
1186 } else if (float_type->width() == 64) {
1187 double fa = a->GetDouble();
1188 double res = fp(fa);
1189 utils::FloatProxy<double> result(res);
1190 std::vector<uint32_t> words = result.GetWords();
1191 return const_mgr->GetConstant(result_type, words);
1192 }
1193 return nullptr;
1194 };
1195}
1196
1197BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1198 double)) {
1199 return
1200 [fp](const analysis::Type* result_type, const analysis::Constant* a,
1201 const analysis::Constant* b,
1202 analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1203 assert(result_type != nullptr && a != nullptr);
1204 const analysis::Float* float_type = a->type()->AsFloat();
1205 assert(float_type != nullptr);
1206 assert(float_type == result_type->AsFloat());
1207 assert(float_type == b->type()->AsFloat());
1208 if (float_type->width() == 32) {
1209 float fa = a->GetFloat();
1210 float fb = b->GetFloat();
1211 float res = static_cast<float>(fp(fa, fb));
1212 utils::FloatProxy<float> result(res);
1213 std::vector<uint32_t> words = result.GetWords();
1214 return const_mgr->GetConstant(result_type, words);
1215 } else if (float_type->width() == 64) {
1216 double fa = a->GetDouble();
1217 double fb = b->GetDouble();
1218 double res = fp(fa, fb);
1219 utils::FloatProxy<double> result(res);
1220 std::vector<uint32_t> words = result.GetWords();
1221 return const_mgr->GetConstant(result_type, words);
1222 }
1223 return nullptr;
1224 };
1225}
Chris Forbescc5697f2019-01-30 11:54:08 -08001226} // namespace
1227
Ben Claytond0f684e2019-08-30 22:36:08 +01001228void ConstantFoldingRules::AddFoldingRules() {
Chris Forbescc5697f2019-01-30 11:54:08 -08001229 // Add all folding rules to the list for the opcodes to which they apply.
1230 // Note that the order in which rules are added to the list matters. If a rule
1231 // applies to the instruction, the rest of the rules will not be attempted.
1232 // Take that into consideration.
1233
1234 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
1235
1236 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
1237
1238 rules_[SpvOpConvertFToS].push_back(FoldFToI());
1239 rules_[SpvOpConvertFToU].push_back(FoldFToI());
1240 rules_[SpvOpConvertSToF].push_back(FoldIToF());
1241 rules_[SpvOpConvertUToF].push_back(FoldIToF());
1242
1243 rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
1244 rules_[SpvOpFAdd].push_back(FoldFAdd());
1245 rules_[SpvOpFDiv].push_back(FoldFDiv());
1246 rules_[SpvOpFMul].push_back(FoldFMul());
1247 rules_[SpvOpFSub].push_back(FoldFSub());
1248
1249 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
1250
1251 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
1252
1253 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1254
1255 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1256
1257 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
1258 rules_[SpvOpFOrdLessThan].push_back(
1259 FoldFClampFeedingCompare(SpvOpFOrdLessThan));
1260
1261 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
1262 rules_[SpvOpFUnordLessThan].push_back(
1263 FoldFClampFeedingCompare(SpvOpFUnordLessThan));
1264
1265 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1266 rules_[SpvOpFOrdGreaterThan].push_back(
1267 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
1268
1269 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1270 rules_[SpvOpFUnordGreaterThan].push_back(
1271 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
1272
1273 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1274 rules_[SpvOpFOrdLessThanEqual].push_back(
1275 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
1276
1277 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1278 rules_[SpvOpFUnordLessThanEqual].push_back(
1279 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
1280
1281 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1282 rules_[SpvOpFOrdGreaterThanEqual].push_back(
1283 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
1284
1285 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
1286 rules_[SpvOpFUnordGreaterThanEqual].push_back(
1287 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
1288
1289 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1290 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1291
1292 rules_[SpvOpFNegate].push_back(FoldFNegate());
Ben Claytonb73b7602019-07-29 13:56:13 +01001293 rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16());
Ben Claytond0f684e2019-08-30 22:36:08 +01001294
1295 // Add rules for GLSLstd450
1296 FeatureManager* feature_manager = context_->get_feature_mgr();
1297 uint32_t ext_inst_glslstd450_id =
1298 feature_manager->GetExtInstImportId_GLSLstd450();
1299 if (ext_inst_glslstd450_id != 0) {
1300 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
Ben Claytond552f632019-11-18 11:18:41 +00001301 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1302 FoldFPBinaryOp(FoldMin));
1303 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1304 FoldFPBinaryOp(FoldMin));
1305 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1306 FoldFPBinaryOp(FoldMin));
1307 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1308 FoldFPBinaryOp(FoldMax));
1309 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1310 FoldFPBinaryOp(FoldMax));
1311 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1312 FoldFPBinaryOp(FoldMax));
1313 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1314 FoldClamp1);
1315 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1316 FoldClamp2);
1317 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1318 FoldClamp3);
1319 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1320 FoldClamp1);
1321 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1322 FoldClamp2);
1323 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1324 FoldClamp3);
1325 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1326 FoldClamp1);
1327 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1328 FoldClamp2);
1329 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1330 FoldClamp3);
Ben Claytondc6b76a2020-02-24 14:53:40 +00001331 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1332 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1333 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1334 FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1335 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1336 FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1337 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1338 FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1339 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1340 FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1341 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1342 FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1343 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1344 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1345 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1346 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1347
1348#ifdef __ANDROID__
1349 // Android NDK r15c tageting ABI 15 doesn't have full support for C++11
1350 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1351 // available up until ABI 18 so we use a shim
1352 auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1353 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1354 FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1355 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1356 FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1357#else
1358 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1359 FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1360 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1361 FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1362#endif
1363
1364 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1365 FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1366 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1367 FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1368 ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1369 FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
Ben Claytond0f684e2019-08-30 22:36:08 +01001370 }
Chris Forbescc5697f2019-01-30 11:54:08 -08001371}
1372} // namespace opt
1373} // namespace spvtools