Simplify the type adjustment in the IslExprBuilder
We now have a simple helper, unifyTypes, that sign-extends two (or three)
operands to their widest integer type before an operation that requires the same type for all operands.
Due to this change, parameters that are added are no longer promoted to i64
unless a wider type is actually needed.
llvm-svn: 271513
diff --git a/polly/lib/CodeGen/IslExprBuilder.cpp b/polly/lib/CodeGen/IslExprBuilder.cpp
index 2570ed9..60e1624 100644
--- a/polly/lib/CodeGen/IslExprBuilder.cpp
+++ b/polly/lib/CodeGen/IslExprBuilder.cpp
@@ -74,6 +74,8 @@
Value *IslExprBuilder::createBinOp(BinaryOperator::BinaryOps Opc, Value *LHS,
Value *RHS, const Twine &Name) {
+ unifyTypes(LHS, RHS);
+
// Handle the plain operation (without overflow tracking) first.
if (!OverflowState) {
switch (Opc) {
@@ -137,7 +139,7 @@
return createBinOp(Instruction::Mul, LHS, RHS, Name);
}
-Type *IslExprBuilder::getWidestType(Type *T1, Type *T2) {
+static Type *getWidestType(Type *T1, Type *T2) {
assert(isa<IntegerType>(T1) && isa<IntegerType>(T2));
if (T1->getPrimitiveSizeInBits() < T2->getPrimitiveSizeInBits())
@@ -146,23 +148,29 @@
return T1;
}
+void IslExprBuilder::unifyTypes(Value *&V0, Value *&V1, Value *&V2) {
+ auto *T0 = V0->getType();
+ auto *T1 = V1->getType();
+ auto *T2 = V2->getType();
+ if (T0 == T1 && T1 == T2)
+ return;
+ auto *MaxT = getWidestType(T0, T1);
+ MaxT = getWidestType(MaxT, T2);
+ V0 = Builder.CreateSExt(V0, MaxT);
+ V1 = Builder.CreateSExt(V1, MaxT);
+ V2 = Builder.CreateSExt(V2, MaxT);
+}
+
Value *IslExprBuilder::createOpUnary(__isl_take isl_ast_expr *Expr) {
assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_minus &&
"Unsupported unary operation");
- Value *V;
- Type *MaxType = getType(Expr);
- assert(MaxType->isIntegerTy() &&
+ auto *V = create(isl_ast_expr_get_op_arg(Expr, 0));
+ assert(V->getType()->isIntegerTy() &&
"Unary expressions can only be created for integer types");
- V = create(isl_ast_expr_get_op_arg(Expr, 0));
- MaxType = getWidestType(MaxType, V->getType());
-
- if (MaxType != V->getType())
- V = Builder.CreateSExt(V, MaxType);
-
isl_ast_expr_free(Expr);
- return createSub(ConstantInt::getNullValue(MaxType), V);
+ return createSub(ConstantInt::getNullValue(V->getType()), V);
}
Value *IslExprBuilder::createOpNAry(__isl_take isl_ast_expr *Expr) {
@@ -179,13 +187,7 @@
Value *OpV;
OpV = create(isl_ast_expr_get_op_arg(Expr, i));
- Type *Ty = getWidestType(V->getType(), OpV->getType());
-
- if (Ty != OpV->getType())
- OpV = Builder.CreateSExt(OpV, Ty);
-
- if (Ty != V->getType())
- V = Builder.CreateSExt(V, Ty);
+ unifyTypes(V, OpV);
switch (isl_ast_expr_get_op_type(Expr)) {
default:
@@ -250,18 +252,8 @@
assert(NextIndex->getType()->isIntegerTy() &&
"Access index should be an integer");
- if (!IndexOp) {
- IndexOp = NextIndex;
- } else {
- Type *Ty = getWidestType(NextIndex->getType(), IndexOp->getType());
-
- if (Ty != NextIndex->getType())
- NextIndex = Builder.CreateIntCast(NextIndex, Ty, true);
- if (Ty != IndexOp->getType())
- IndexOp = Builder.CreateIntCast(IndexOp, Ty, true);
-
- IndexOp = createAdd(IndexOp, NextIndex, "polly.access.add." + BaseName);
- }
+ IndexOp = !IndexOp ? NextIndex : createAdd(IndexOp, NextIndex,
+ "polly.access.add." + BaseName);
// For every but the last dimension multiply the size, for the last
// dimension we can exit the loop.
@@ -276,14 +268,6 @@
expandCodeFor(S, SE, DL, "polly", DimSCEV, DimSCEV->getType(),
&*Builder.GetInsertPoint());
- Type *Ty = getWidestType(DimSize->getType(), IndexOp->getType());
-
- if (Ty != IndexOp->getType())
- IndexOp = Builder.CreateSExtOrTrunc(IndexOp, Ty,
- "polly.access.sext." + BaseName);
- if (Ty != DimSize->getType())
- DimSize = Builder.CreateSExtOrTrunc(DimSize, Ty,
- "polly.access.sext." + BaseName);
IndexOp = createMul(IndexOp, DimSize, "polly.access.mul." + BaseName);
}
@@ -301,7 +285,6 @@
Value *IslExprBuilder::createOpBin(__isl_take isl_ast_expr *Expr) {
Value *LHS, *RHS, *Res;
- Type *MaxType;
isl_ast_op_type OpType;
assert(isl_ast_expr_get_type(Expr) == isl_ast_expr_op &&
@@ -314,41 +297,25 @@
LHS = create(isl_ast_expr_get_op_arg(Expr, 0));
RHS = create(isl_ast_expr_get_op_arg(Expr, 1));
- Type *LHSType = LHS->getType();
- Type *RHSType = RHS->getType();
-
- MaxType = getWidestType(LHSType, RHSType);
-
- // Take the result into account when calculating the widest type.
- //
- // For operations such as '+' the result may require a type larger than
- // the type of the individual operands. For other operations such as '/', the
- // result type cannot be larger than the type of the individual operand. isl
- // does not calculate correct types for these operations and we consequently
- // exclude those operations here.
+ // For possibly overflowing operations we will later adjust types but
+ // for others we do it now as we will directly create the operations.
switch (OpType) {
case isl_ast_op_pdiv_q:
case isl_ast_op_pdiv_r:
case isl_ast_op_div:
case isl_ast_op_fdiv_q:
case isl_ast_op_zdiv_r:
- // Do nothing
+ unifyTypes(LHS, RHS);
break;
case isl_ast_op_add:
case isl_ast_op_sub:
case isl_ast_op_mul:
- MaxType = getWidestType(MaxType, getType(Expr));
+ // Do nothing
break;
default:
llvm_unreachable("This is no binary isl ast expression");
}
- if (MaxType != RHS->getType())
- RHS = Builder.CreateSExt(RHS, MaxType);
-
- if (MaxType != LHS->getType())
- LHS = Builder.CreateSExt(LHS, MaxType);
-
switch (OpType) {
default:
llvm_unreachable("This is no binary isl ast expression");
@@ -379,13 +346,15 @@
// incorrect overflow in some bordercases.
//
// floord(n,d) ((n < 0) ? (n - d + 1) : n) / d
- Value *One = ConstantInt::get(MaxType, 1);
- Value *Zero = ConstantInt::get(MaxType, 0);
Value *Sum1 = createSub(LHS, RHS, "pexp.fdiv_q.0");
+ Value *One = ConstantInt::get(Sum1->getType(), 1);
Value *Sum2 = createAdd(Sum1, One, "pexp.fdiv_q.1");
+ Value *Zero = ConstantInt::get(LHS->getType(), 0);
Value *isNegative = Builder.CreateICmpSLT(LHS, Zero, "pexp.fdiv_q.2");
+ unifyTypes(LHS, Sum2);
Value *Dividend =
Builder.CreateSelect(isNegative, Sum2, LHS, "pexp.fdiv_q.3");
+ unifyTypes(Dividend, RHS);
Res = Builder.CreateSDiv(Dividend, RHS, "pexp.fdiv_q.4");
break;
}
@@ -410,7 +379,6 @@
assert(isl_ast_expr_get_op_type(Expr) == isl_ast_op_select &&
"Unsupported unary isl ast expression");
Value *LHS, *RHS, *Cond;
- Type *MaxType = getType(Expr);
Cond = create(isl_ast_expr_get_op_arg(Expr, 0));
if (!Cond->getType()->isIntegerTy(1))
@@ -418,17 +386,8 @@
LHS = create(isl_ast_expr_get_op_arg(Expr, 1));
RHS = create(isl_ast_expr_get_op_arg(Expr, 2));
+ unifyTypes(LHS, RHS);
- MaxType = getWidestType(MaxType, LHS->getType());
- MaxType = getWidestType(MaxType, RHS->getType());
-
- if (MaxType != RHS->getType())
- RHS = Builder.CreateSExt(RHS, MaxType);
-
- if (MaxType != LHS->getType())
- LHS = Builder.CreateSExt(LHS, MaxType);
-
- // TODO: Do we want to truncate the result?
isl_ast_expr_free(Expr);
return Builder.CreateSelect(Cond, LHS, RHS);
}
@@ -461,16 +420,7 @@
if (RHSTy->isPointerTy())
RHS = Builder.CreatePtrToInt(RHS, PtrAsIntTy);
- if (LHS->getType() != RHS->getType()) {
- Type *MaxType = LHS->getType();
- MaxType = getWidestType(MaxType, RHS->getType());
-
- if (MaxType != RHS->getType())
- RHS = Builder.CreateSExt(RHS, MaxType);
-
- if (MaxType != LHS->getType())
- LHS = Builder.CreateSExt(LHS, MaxType);
- }
+ unifyTypes(LHS, RHS);
isl_ast_op_type OpType = isl_ast_expr_get_op_type(Expr);
assert(OpType >= isl_ast_op_eq && OpType <= isl_ast_op_gt &&
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 8ef6476..90b27d4 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -390,14 +390,7 @@
Value *ValueLB = ExprBuilder.create(Init);
Value *ValueInc = ExprBuilder.create(Inc);
- Type *MaxType = ExprBuilder.getType(Iterator);
- MaxType = ExprBuilder.getWidestType(MaxType, ValueLB->getType());
- MaxType = ExprBuilder.getWidestType(MaxType, ValueInc->getType());
-
- if (MaxType != ValueLB->getType())
- ValueLB = Builder.CreateSExt(ValueLB, MaxType);
- if (MaxType != ValueInc->getType())
- ValueInc = Builder.CreateSExt(ValueInc, MaxType);
+ ExprBuilder.unifyTypes(ValueLB, ValueInc);
std::vector<Value *> IVS(VectorWidth);
IVS[0] = ValueLB;
@@ -472,17 +465,7 @@
ValueUB = ExprBuilder.create(UB);
ValueInc = ExprBuilder.create(Inc);
- MaxType = ExprBuilder.getType(Iterator);
- MaxType = ExprBuilder.getWidestType(MaxType, ValueLB->getType());
- MaxType = ExprBuilder.getWidestType(MaxType, ValueUB->getType());
- MaxType = ExprBuilder.getWidestType(MaxType, ValueInc->getType());
-
- if (MaxType != ValueLB->getType())
- ValueLB = Builder.CreateSExt(ValueLB, MaxType);
- if (MaxType != ValueUB->getType())
- ValueUB = Builder.CreateSExt(ValueUB, MaxType);
- if (MaxType != ValueInc->getType())
- ValueInc = Builder.CreateSExt(ValueInc, MaxType);
+ ExprBuilder.unifyTypes(ValueLB, ValueUB, ValueInc);
// If we can show that LB <Predicate> UB holds at least once, we can
// omit the GuardBB in front of the loop.
@@ -583,17 +566,7 @@
ValueUB = Builder.CreateAdd(
ValueUB, Builder.CreateSExt(Builder.getTrue(), ValueUB->getType()));
- MaxType = ExprBuilder.getType(Iterator);
- MaxType = ExprBuilder.getWidestType(MaxType, ValueLB->getType());
- MaxType = ExprBuilder.getWidestType(MaxType, ValueUB->getType());
- MaxType = ExprBuilder.getWidestType(MaxType, ValueInc->getType());
-
- if (MaxType != ValueLB->getType())
- ValueLB = Builder.CreateSExt(ValueLB, MaxType);
- if (MaxType != ValueUB->getType())
- ValueUB = Builder.CreateSExt(ValueUB, MaxType);
- if (MaxType != ValueInc->getType())
- ValueInc = Builder.CreateSExt(ValueInc, MaxType);
+ ExprBuilder.unifyTypes(ValueLB, ValueUB, ValueInc);
BasicBlock::iterator LoopBody;
diff --git a/polly/lib/CodeGen/LoopGenerators.cpp b/polly/lib/CodeGen/LoopGenerators.cpp
index 566e460..960e0cc 100644
--- a/polly/lib/CodeGen/LoopGenerators.cpp
+++ b/polly/lib/CodeGen/LoopGenerators.cpp
@@ -61,6 +61,8 @@
assert(LB->getType() == UB->getType() && "Types of loop bounds do not match");
IntegerType *LoopIVType = dyn_cast<IntegerType>(UB->getType());
assert(LoopIVType && "UB is not integer?");
+ assert((LoopIVType == LB->getType() && LoopIVType == Stride->getType()) &&
+ "LB, UB and Stride should have equal types.");
BasicBlock *BeforeBB = Builder.GetInsertBlock();
BasicBlock *GuardBB =
@@ -121,7 +123,6 @@
Builder.SetInsertPoint(HeaderBB);
PHINode *IV = Builder.CreatePHI(LoopIVType, 2, "polly.indvar");
IV->addIncoming(LB, PreHeaderBB);
- Stride = Builder.CreateZExtOrBitCast(Stride, LoopIVType);
Value *IncrementedIV = Builder.CreateNSWAdd(IV, Stride, "polly.indvar_next");
Value *LoopCondition;
UB = Builder.CreateSub(UB, Stride, "polly.adjust_ub");
@@ -147,6 +148,12 @@
Value *ParallelLoopGenerator::createParallelLoop(
Value *LB, Value *UB, Value *Stride, SetVector<Value *> &UsedValues,
ValueMapT &Map, BasicBlock::iterator *LoopBody) {
+
+ // Adjust the types to match the GOMP API.
+ LB = Builder.CreateSExt(LB, LongType);
+ UB = Builder.CreateSExt(UB, LongType);
+ Stride = Builder.CreateSExt(Stride, LongType);
+
Function *SubFn;
AllocaInst *Struct = storeValuesIntoStruct(UsedValues);