Added type checking for Mod and bitwise ops

Added type checking and, in a few cases, constant
folding for the integer modulo (%) and bitwise
operators in TIntermediate, TIntermBinary and
TIntermConstantUnion.

Change-Id: Ic6ac624fd8d6d9f407f1b8fac40ae31f54a6c7da
Reviewed-on: https://swiftshader-review.googlesource.com/3113
Tested-by: Alexis Hétu <sugoi@google.com>
Reviewed-by: Nicolas Capens <capn@google.com>
diff --git a/src/OpenGL/compiler/Intermediate.cpp b/src/OpenGL/compiler/Intermediate.cpp
index 7f056fb..fad11c6 100644
--- a/src/OpenGL/compiler/Intermediate.cpp
+++ b/src/OpenGL/compiler/Intermediate.cpp
@@ -177,12 +177,27 @@
                 return 0;
             }
             break;
+        case EOpBitwiseOr:
+        case EOpBitwiseXor:
+        case EOpBitwiseAnd:
+            if ((left->getBasicType() != EbtInt && left->getBasicType() != EbtUInt) || left->isMatrix() || left->isArray()) {
+                return 0;
+            }
+            break;
         case EOpAdd:
         case EOpSub:
         case EOpDiv:
         case EOpMul:
-            if (left->getBasicType() == EbtStruct || left->getBasicType() == EbtBool)
+            if (left->getBasicType() == EbtStruct || left->getBasicType() == EbtBool) {
                 return 0;
+            }
+            break;
+        case EOpIMod:
+            // Note that this is only for the % operator, not for mod()
+            if (left->getBasicType() == EbtStruct || left->getBasicType() == EbtBool || left->getBasicType() == EbtFloat) {
+                return 0;
+            }
+            break;
         default: break;
     }
 
@@ -285,6 +300,12 @@
     }
 
     switch (op) {
+        case EOpBitwiseNot:
+            if ((child->getType().getBasicType() != EbtInt && child->getType().getBasicType() != EbtUInt) || child->getType().isMatrix() || child->getType().isArray()) {
+                return 0;
+            }
+            break;
+
         case EOpLogicalNot:
             if (child->getType().getBasicType() != EbtBool || child->getType().isMatrix() || child->getType().isArray() || child->getType().isVector()) {
                 return 0;
@@ -659,6 +680,10 @@
             if (operand->getBasicType() != EbtBool)
                 return false;
             break;
+        case EOpBitwiseNot:
+            if(operand->getBasicType() != EbtInt && operand->getBasicType() != EbtUInt)
+                return false;
+            break;
         case EOpNegative:
         case EOpPostIncrement:
         case EOpPostDecrement:
@@ -726,13 +751,13 @@
         getTypePointer()->setQualifier(EvqTemporary);
     }
 
-    int size = std::max(left->getNominalSize(), right->getNominalSize());
-	int matrixSize = std::max(left->getSecondarySize(), right->getSecondarySize()); // FIXME: This will have to change for NxM matrices
+    int primarySize = std::max(left->getNominalSize(), right->getNominalSize());
+	int secondarySize = std::max(left->getSecondarySize(), right->getSecondarySize());
 
     //
     // All scalars. Code after this test assumes this case is removed!
     //
-    if (size == 1) {
+    if (primarySize == 1) {
         switch (op) {
             //
             // Promote to conditional
@@ -751,6 +776,7 @@
             //
             case EOpLogicalAnd:
             case EOpLogicalOr:
+            case EOpLogicalXor:
                 // Both operands must be of type bool.
                 if (left->getBasicType() != EbtBool || right->getBasicType() != EbtBool)
                     return false;
@@ -794,12 +820,12 @@
                     op = EOpVectorTimesMatrix;
                 else {
                     op = EOpMatrixTimesScalar;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, size, matrixSize));
+                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, secondarySize));
                 }
             } else if (left->isMatrix() && !right->isMatrix()) {
                 if (right->isVector()) {
                     op = EOpMatrixTimesVector;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, size, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
                 } else {
                     op = EOpMatrixTimesScalar;
                 }
@@ -810,7 +836,7 @@
                     // leave as component product
                 } else if (left->isVector() || right->isVector()) {
                     op = EOpVectorTimesScalar;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, size, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
                 }
             } else {
                 infoSink.info.message(EPrefixInternalError, "Missing elses", getLine());
@@ -839,7 +865,7 @@
                     if (! left->isVector())
                         return false;
                     op = EOpVectorTimesScalarAssign;
-                    setType(TType(basicType, higherPrecision, EvqTemporary, size, 1));
+                    setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, 1));
                 }
             } else {
                 infoSink.info.message(EPrefixInternalError, "Missing elses", getLine());
@@ -852,6 +878,12 @@
         case EOpAdd:
         case EOpSub:
         case EOpDiv:
+        case EOpIMod:
+        case EOpBitShiftLeft:
+        case EOpBitShiftRight:
+        case EOpBitwiseAnd:
+        case EOpBitwiseXor:
+        case EOpBitwiseOr:
         case EOpAddAssign:
         case EOpSubAssign:
         case EOpDivAssign:
@@ -864,7 +896,35 @@
             if ((left->isMatrix() && right->isVector()) ||
                 (left->isVector() && right->isMatrix()))
                 return false;
-            setType(TType(basicType, higherPrecision, EvqTemporary, size, matrixSize));
+
+            // Are the sizes compatible?
+            if(left->getNominalSize() != right->getNominalSize() ||
+               left->getSecondarySize() != right->getSecondarySize())
+            {
+                // If the nominal sizes of operands do not match:
+                // One of them must be a scalar.
+                if(!left->isScalar() && !right->isScalar())
+                    return false;
+
+                // In the case of compound assignment other than multiply-assign,
+                // the right side needs to be a scalar. Otherwise a vector/matrix
+                // would be assigned to a scalar. A scalar can't be shifted by a
+                // vector either.
+                if(!right->isScalar() && (modifiesState() || op == EOpBitShiftLeft || op == EOpBitShiftRight))
+                    return false;
+            }
+
+            {
+                setType(TType(basicType, higherPrecision, EvqTemporary,
+                    static_cast<unsigned char>(primarySize), static_cast<unsigned char>(secondarySize)));
+                if(left->isArray())
+                {
+                    ASSERT(left->getArraySize() == right->getArraySize());
+                    type.setArraySize(left->getArraySize());
+                }
+            }
+
+            setType(TType(basicType, higherPrecision, EvqTemporary, primarySize, secondarySize));
             break;
 
         case EOpEqual:
@@ -1009,35 +1069,50 @@
                 }
                 break;
             case EOpDiv:
+            case EOpIMod:
                 tempConstArray = new ConstantUnion[objectSize];
                 {// support MSVC++6.0
                     for (int i = 0; i < objectSize; i++) {
                         switch (getType().getBasicType()) {
-            case EbtFloat:
-                if (rightUnionArray[i] == 0.0f) {
-                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
-                    tempConstArray[i].setFConst(FLT_MAX);
-                } else
-                    tempConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
-                break;
+                            case EbtFloat:
+                                if (rightUnionArray[i] == 0.0f) {
+                                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
+                                    tempConstArray[i].setFConst(FLT_MAX);
+                                } else {
+                                    ASSERT(op == EOpDiv);
+                                    tempConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
+                                }
+                                break;
 
-            case EbtInt:
-                if (rightUnionArray[i] == 0) {
-                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
-                    tempConstArray[i].setIConst(INT_MAX);
-                } else
-                    tempConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
-                break;
-            case EbtUInt:
-                if (rightUnionArray[i] == 0) {
-                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
-                    tempConstArray[i].setUConst(UINT_MAX);
-                } else
-                    tempConstArray[i].setUConst(unionArray[i].getUConst() / rightUnionArray[i].getUConst());
-                break;
-            default:
-                infoSink.info.message(EPrefixInternalError, "Constant folding cannot be done for \"/\"", getLine());
-                return 0;
+                            case EbtInt:
+                                if (rightUnionArray[i] == 0) {
+                                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
+                                    tempConstArray[i].setIConst(INT_MAX);
+                                } else {
+                                    if(op == EOpDiv) {
+                                        tempConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
+                                    } else {
+                                        ASSERT(op == EOpIMod);
+                                        tempConstArray[i].setIConst(unionArray[i].getIConst() % rightUnionArray[i].getIConst());
+                                    }
+                                }
+                                break;
+                            case EbtUInt:
+                                if (rightUnionArray[i] == 0) {
+                                    infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
+                                    tempConstArray[i].setUConst(UINT_MAX);
+                                } else {
+                                    if(op == EOpDiv) {
+                                        tempConstArray[i].setUConst(unionArray[i].getUConst() / rightUnionArray[i].getUConst());
+                                    } else {
+                                        ASSERT(op == EOpIMod);
+                                        tempConstArray[i].setUConst(unionArray[i].getUConst() % rightUnionArray[i].getUConst());
+                                    }
+                                }
+                                break;
+                            default:
+                                infoSink.info.message(EPrefixInternalError, "Constant folding cannot be done for \"/\"", getLine());
+                                return 0;
                         }
                     }
                 }
@@ -1102,12 +1177,38 @@
                 {// support MSVC++6.0
                     for (int i = 0; i < objectSize; i++)
                         switch (getType().getBasicType()) {
-            case EbtBool: tempConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
-            default: assert(false && "Default missing");
+                            case EbtBool: tempConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
+                            default: assert(false && "Default missing");
                     }
                 }
                 break;
 
+            case EOpBitwiseAnd:
+                tempConstArray = new ConstantUnion[objectSize];
+                for(int i = 0; i < objectSize; i++)
+                    tempConstArray[i] = unionArray[i] & rightUnionArray[i];
+                break;
+            case EOpBitwiseXor:
+                tempConstArray = new ConstantUnion[objectSize];
+                for(int i = 0; i < objectSize; i++)
+                    tempConstArray[i] = unionArray[i] ^ rightUnionArray[i];
+                break;
+            case EOpBitwiseOr:
+                tempConstArray = new ConstantUnion[objectSize];
+                for(int i = 0; i < objectSize; i++)
+                    tempConstArray[i] = unionArray[i] | rightUnionArray[i];
+                break;
+            case EOpBitShiftLeft:
+                tempConstArray = new ConstantUnion[objectSize];
+                for(int i = 0; i < objectSize; i++)
+                    tempConstArray[i] = unionArray[i] << rightUnionArray[i];
+                break;
+            case EOpBitShiftRight:
+                tempConstArray = new ConstantUnion[objectSize];
+                for(int i = 0; i < objectSize; i++)
+                    tempConstArray[i] = unionArray[i] >> rightUnionArray[i];
+                break;
+
             case EOpLessThan:
                 assert(objectSize == 1);
                 tempConstArray = new ConstantUnion[1];
@@ -1226,6 +1327,15 @@
                             return 0;
                     }
                     break;
+                case EOpBitwiseNot:
+                    switch(getType().getBasicType()) {
+                        case EbtInt: tempConstArray[i].setIConst(~unionArray[i].getIConst()); break;
+                        case EbtUInt: tempConstArray[i].setUConst(~unionArray[i].getUConst()); break;
+                        default:
+                            infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
+                            return 0;
+                    }
+                    break;
                 default:
                     return 0;
             }