[Matrix] Add matrix type to Clang.
This patch adds a matrix type to Clang as described in the draft
specification in clang/docs/MatrixSupport.rst. It introduces a new option
-fenable-matrix, which can be used to enable the matrix support.
The patch adds new MatrixType and DependentSizedMatrixType types along
with the plumbing required. Loads of and stores to pointers to matrix
values are lowered to memory operations on 1-D IR arrays. After loading,
the loaded values are cast to a vector. This ensures matrix values use
the alignment of the element type, instead of LLVM's large vector
alignment.
The operators and builtins described in the draft spec will be added in
follow-up patches.
Reviewers: martong, rsmith, Bigcheese, anemet, dexonsmith, rjmccall, aaron.ballman
Reviewed By: rjmccall
Differential Revision: https://reviews.llvm.org/D72281
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 4ed0073..8f38581 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1932,6 +1932,17 @@
break;
}
+ case Type::ConstantMatrix: {
+ const auto *MT = cast<ConstantMatrixType>(T);
+ TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+ // The internal layout of a matrix value is implementation defined.
+ // Initially be ABI compatible with arrays with respect to alignment and
+ // size.
+ Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+ Align = ElementInfo.Align;
+ break;
+ }
+
case Type::Builtin:
switch (cast<BuiltinType>(T)->getKind()) {
default: llvm_unreachable("Unknown builtin type!");
@@ -3362,6 +3373,8 @@
case Type::DependentVector:
case Type::ExtVector:
case Type::DependentSizedExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::ObjCObject:
case Type::ObjCInterface:
@@ -3775,6 +3788,78 @@
return QualType(New, 0);
}
+/// Return the ConstantMatrixType with element type \p ElementTy and
+/// \p NumRows x \p NumColumns dimensions, uniqued in MatrixTypes.
+///
+/// If \p ElementTy is not canonical, the matrix type over the canonical
+/// element type is built first and becomes the new node's canonical type.
+QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
+ unsigned NumColumns) const {
+ llvm::FoldingSetNodeID ID;
+ ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
+ Type::ConstantMatrix);
+
+ assert(MatrixType::isValidElementType(ElementTy) &&
+ "need a valid element type");
+ assert(ConstantMatrixType::isDimensionValid(NumRows) &&
+ ConstantMatrixType::isDimensionValid(NumColumns) &&
+ "need valid matrix dimensions");
+ void *InsertPos = nullptr;
+ if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos))
+ return QualType(MTP, 0);
+
+ // A null Canonical presumably marks the new node as its own canonical type.
+ QualType Canonical;
+ if (!ElementTy.isCanonical()) {
+ Canonical =
+ getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+
+ // The recursive call above inserts into MatrixTypes, so recompute
+ // InsertPos (and check this exact node still does not exist).
+ ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ assert(!NewIP && "Matrix type shouldn't already exist in the map");
+ // Silence the unused-variable warning when asserts are compiled out.
+ (void)NewIP;
+ }
+
+ auto *New = new (*this, TypeAlignment)
+ ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical);
+ MatrixTypes.InsertNode(New, InsertPos);
+ Types.push_back(New);
+ return QualType(New, 0);
+}
+
+/// Return a DependentSizedMatrixType for \p ElementTy whose row and column
+/// counts are the (dependent) expressions \p RowExpr and \p ColumnExpr.
+///
+/// The canonical node is keyed on the canonical element type plus the
+/// canonical profiles of both expressions and is uniqued in
+/// DependentSizedMatrixTypes. If the request is not spelled exactly like that
+/// canonical node, a new (non-uniqued) sugar node pointing at it is returned.
+QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy,
+                                                 Expr *RowExpr,
+                                                 Expr *ColumnExpr,
+                                                 SourceLocation AttrLoc) const {
+  QualType CanonElementTy = getCanonicalType(ElementTy);
+  llvm::FoldingSetNodeID ID;
+  DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr,
+                                    ColumnExpr);
+
+  void *InsertPos = nullptr;
+  DependentSizedMatrixType *Canon =
+      DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+
+  if (!Canon) {
+    Canon = new (*this, TypeAlignment) DependentSizedMatrixType(
+        *this, CanonElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc);
+#ifndef NDEBUG
+    DependentSizedMatrixType *CanonCheck =
+        DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+    assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+#endif
+    DependentSizedMatrixTypes.InsertNode(Canon, InsertPos);
+    Types.push_back(Canon);
+  }
+
+  // Reuse the canonical node directly only when the request is spelled
+  // exactly like it: same element type, the same row expression and the same
+  // column expression (each compared against its own counterpart).
+  if (Canon->getElementType() == ElementTy && Canon->getRowExpr() == RowExpr &&
+      Canon->getColumnExpr() == ColumnExpr)
+    return QualType(Canon, 0);
+
+  // Otherwise build a sugar node whose canonical type is Canon.
+  DependentSizedMatrixType *New = new (*this, TypeAlignment)
+      DependentSizedMatrixType(*this, ElementTy, QualType(Canon, 0), RowExpr,
+                               ColumnExpr, AttrLoc);
+  Types.push_back(New);
+  return QualType(New, 0);
+}
+
QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
Expr *AddrSpaceExpr,
SourceLocation AttrLoc) const {
@@ -7338,6 +7423,11 @@
*NotEncodedT = T;
return;
+ case Type::ConstantMatrix:
+ if (NotEncodedT)
+ *NotEncodedT = T;
+ return;
+
// We could see an undeduced auto type here during error recovery.
// Just ignore it.
case Type::Auto:
@@ -8217,6 +8307,16 @@
LHS->getNumElements() == RHS->getNumElements();
}
+/// areCompatMatrixTypes - Two canonical, unqualified constant matrix types are
+/// compatible exactly when their row counts, column counts and element types
+/// all agree.
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS,
+                                 const ConstantMatrixType *RHS) {
+  assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+  if (LHS->getNumRows() != RHS->getNumRows() ||
+      LHS->getNumColumns() != RHS->getNumColumns())
+    return false;
+  return LHS->getElementType() == RHS->getElementType();
+}
+
bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
QualType SecondVec) {
assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9414,6 +9514,11 @@
RHSCan->castAs<VectorType>()))
return LHS;
return {};
+ case Type::ConstantMatrix:
+ if (areCompatMatrixTypes(LHSCan->castAs<ConstantMatrixType>(),
+ RHSCan->castAs<ConstantMatrixType>()))
+ return LHS;
+ return {};
case Type::ObjCObject: {
// Check if the types are assignment compatible.
// FIXME: This should be type compatibility, e.g. whether
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index c562830..8b5b244 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,34 @@
break;
}
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+ const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+ // The element types, row and column expressions must be structurally
+ // equivalent.
+ if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+ Mat2->getRowExpr()) ||
+ !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+ Mat2->getColumnExpr()) ||
+ !IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType()))
+ return false;
+ break;
+ }
+
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *Mat1 = cast<ConstantMatrixType>(T1);
+ const ConstantMatrixType *Mat2 = cast<ConstantMatrixType>(T2);
+ // The element types must be structurally equivalent and the number of rows
+ // and columns must match.
+ if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType()) ||
+ Mat1->getNumRows() != Mat2->getNumRows() ||
+ Mat1->getNumColumns() != Mat2->getNumColumns())
+ return false;
+ break;
+ }
+
case Type::FunctionProto: {
const auto *Proto1 = cast<FunctionProtoType>(T1);
const auto *Proto2 = cast<FunctionProtoType>(T2);
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 1c738ff..40c60f1 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -10350,6 +10350,7 @@
case Type::BlockPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::ObjCObject:
case Type::ObjCInterface:
case Type::ObjCObjectPointer:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index d60cacf..dbf004b 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2079,6 +2079,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -3343,6 +3345,31 @@
mangleType(T->getElementType());
}
+/// Itanium mangling for a matrix with constant dimensions.
+void CXXNameMangler::mangleType(const ConstantMatrixType *T) {
+ // Mangle matrix types using a vendor extended type qualifier:
+ // U<Len>matrix_type<Rows><Columns><element type>
+ StringRef VendorQualifier = "matrix_type";
+ Out << "U" << VendorQualifier.size() << VendorQualifier;
+ auto &ASTCtx = getASTContext();
+ // Both dimensions are mangled as integer literals of type size_t, held in
+ // APSInts sized to size_t's bit width.
+ unsigned BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType());
+ llvm::APSInt Rows(BitWidth);
+ Rows = T->getNumRows();
+ mangleIntegerLiteral(ASTCtx.getSizeType(), Rows);
+ llvm::APSInt Columns(BitWidth);
+ Columns = T->getNumColumns();
+ mangleIntegerLiteral(ASTCtx.getSizeType(), Columns);
+ mangleType(T->getElementType());
+}
+
+/// Itanium mangling for a matrix whose dimensions are dependent expressions.
+/// Same vendor-qualifier shape as the constant case, but each dimension is
+/// mangled as a template-argument expression.
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+ // U<Len>matrix_type<row expr><column expr><element type>
+ StringRef VendorQualifier = "matrix_type";
+ Out << "U" << VendorQualifier.size() << VendorQualifier;
+ mangleTemplateArg(T->getRowExpr());
+ mangleTemplateArg(T->getColumnExpr());
+ mangleType(T->getElementType());
+}
+
void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
SplitQualType split = T->getPointeeType().split();
mangleQualifiers(split.Quals, T);
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index dc5c15f..e3796ac 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -2730,6 +2730,23 @@
<< Range;
}
+// No Microsoft mangling is implemented for matrix types; both overloads
+// report an error diagnostic covering the type's source range instead.
+void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  Diags.Report(Range.getBegin(),
+               Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                     "Cannot mangle this matrix type yet"))
+      << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+                                         Qualifiers quals, SourceRange Range) {
+  DiagnosticsEngine &Diags = Context.getDiags();
+  Diags.Report(Range.getBegin(),
+               Diags.getCustomDiagID(
+                   DiagnosticsEngine::Error,
+                   "Cannot mangle this dependent-sized matrix type yet"))
+      << Range;
+}
+
void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
Qualifiers, SourceRange Range) {
DiagnosticsEngine &Diags = Context.getDiags();
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 3408149..e4d8af9 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -282,6 +282,53 @@
AddrSpaceExpr->Profile(ID, Context, true);
}
+// Common base constructor for matrix types.
+// When RowExpr is non-null (the dependent-sized case), the type is marked
+// Dependent and Instantiation-dependent on top of the element type's
+// dependence. It is additionally VariablyModified if the element type is, and
+// UnexpandedPack if the element type or either dimension expression contains
+// an unexpanded parameter pack. When RowExpr is null (the constant case), the
+// type simply inherits the element type's dependence.
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType,
+ const Expr *RowExpr, const Expr *ColumnExpr)
+ : Type(tc, canonType,
+ (RowExpr ? (matrixType->getDependence() | TypeDependence::Dependent |
+ TypeDependence::Instantiation |
+ (matrixType->isVariablyModifiedType()
+ ? TypeDependence::VariablyModified
+ : TypeDependence::None) |
+ (matrixType->containsUnexpandedParameterPack() ||
+ (RowExpr &&
+ RowExpr->containsUnexpandedParameterPack()) ||
+ (ColumnExpr &&
+ ColumnExpr->containsUnexpandedParameterPack())
+ ? TypeDependence::UnexpandedPack
+ : TypeDependence::None))
+ : matrixType->getDependence())),
+ ElementType(matrixType) {}
+
+// Public constructor: delegates with the ConstantMatrix type class.
+ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows,
+ unsigned nColumns, QualType canonType)
+ : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns,
+ canonType) {}
+
+// Stores the dimensions in the ConstantMatrixTypeBits bitfields. No
+// dimension expressions are passed to MatrixType, so dependence comes from
+// the element type only.
+ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType,
+ unsigned nRows, unsigned nColumns,
+ QualType canonType)
+ : MatrixType(tc, matrixType, canonType) {
+ ConstantMatrixTypeBits.NumRows = nRows;
+ ConstantMatrixTypeBits.NumColumns = nColumns;
+}
+
+// Keeps the ASTContext (needed for profiling), the dimension expressions and
+// the attribute location.
+DependentSizedMatrixType::DependentSizedMatrixType(
+ const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+ Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+ : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr,
+ ColumnExpr),
+ Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {}
+
+// FoldingSet profile: the element type's opaque pointer plus the profiles of
+// both dimension expressions (the `true` argument presumably requests
+// canonical profiling). RowExpr and ColumnExpr are dereferenced and must be
+// non-null.
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+ const ASTContext &CTX,
+ QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr) {
+ ID.AddPointer(ElementType.getAsOpaquePtr());
+ RowExpr->Profile(ID, CTX, true);
+ ColumnExpr->Profile(ID, CTX, true);
+}
+
VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
VectorKind vecKind)
: VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -971,6 +1018,17 @@
return Ctx.getExtVectorType(elementType, T->getNumElements());
}
+  // Transform the element type; a null result is propagated unchanged, and
+  // the matrix node is only rebuilt if the element type actually changed.
+  QualType VisitConstantMatrixType(const ConstantMatrixType *T) {
+    QualType NewElementTy = recurse(T->getElementType());
+    if (NewElementTy.isNull())
+      return {};
+
+    const bool ElementChanged = NewElementTy.getAsOpaquePtr() !=
+                                T->getElementType().getAsOpaquePtr();
+    if (!ElementChanged)
+      return QualType(T, 0);
+    return Ctx.getConstantMatrixType(NewElementTy, T->getNumRows(),
+                                     T->getNumColumns());
+  }
+
QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
QualType returnType = recurse(T->getReturnType());
if (returnType.isNull())
@@ -1790,6 +1848,14 @@
return Visit(T->getElementType());
}
+ // Matrix types are looked through to their element type.
+ Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
+ // Same for matrices with constant dimensions.
+ Type *VisitConstantMatrixType(const ConstantMatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
Type *VisitFunctionProtoType(const FunctionProtoType *T) {
if (Syntactic && T->hasTrailingReturn())
return const_cast<FunctionProtoType*>(T);
@@ -3744,6 +3810,8 @@
case Type::Vector:
case Type::ExtVector:
return Cache::get(cast<VectorType>(T)->getElementType());
+ case Type::ConstantMatrix:
+ return Cache::get(cast<ConstantMatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return Cache::get(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3830,6 +3898,9 @@
case Type::Vector:
case Type::ExtVector:
return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+ case Type::ConstantMatrix:
+ return computeTypeLinkageInfo(
+ cast<ConstantMatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3993,6 +4064,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::FunctionProto:
case Type::FunctionNoProto:
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cf82d1a..6f6932e 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -256,6 +256,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -720,6 +722,38 @@
OS << ")))";
}
+// Prints the element type followed by
+// " __attribute__((matrix_type(<rows>, <columns>)))".
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+ raw_ostream &OS) {
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ OS << T->getNumRows() << ", " << T->getNumColumns();
+ OS << ")))";
+}
+
+// Only the element type contributes anything after the declarator.
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
+ raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
+// Like the constant case, but the dimensions are pretty-printed expressions.
+// A null dimension expression prints as nothing (the ", " separator is
+// always emitted).
+void TypePrinter::printDependentSizedMatrixBefore(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ if (T->getRowExpr()) {
+ T->getRowExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ", ";
+ if (T->getColumnExpr()) {
+ T->getColumnExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ")))";
+}
+
+// Only the element type contributes anything after the declarator.
+void TypePrinter::printDependentSizedMatrixAfter(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
void
FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
const PrintingPolicy &Policy)
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index e6422a7..0c23b16 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2736,6 +2736,23 @@
return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
}
+/// Emit debug info for a constant matrix as a two-dimensional DI array type
+/// using the matrix's size and (if required) alignment.
+llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty,
+ llvm::DIFile *Unit) {
+ // FIXME: Create another debug type for matrices
+ // For the time being, it treats it like a nested ArrayType.
+
+ llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+ uint64_t Size = CGM.getContext().getTypeSize(Ty);
+ uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext());
+
+ // Create ranges for both dimensions.
+ // Columns are the outer subscript and rows the inner one; presumably this
+ // mirrors a column-major layout. TODO confirm against MatrixSupport.rst.
+ llvm::SmallVector<llvm::Metadata *, 2> Subscripts;
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()));
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows()));
+ llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+ return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
uint64_t Size;
uint32_t Align;
@@ -3129,6 +3146,8 @@
case Type::ExtVector:
case Type::Vector:
return CreateType(cast<VectorType>(Ty), Unit);
+ case Type::ConstantMatrix:
+ return CreateType(cast<ConstantMatrixType>(Ty), Unit);
case Type::ObjCObjectPointer:
return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
case Type::ObjCObject:
diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h
index 34164fb..367047e 100644
--- a/clang/lib/CodeGen/CGDebugInfo.h
+++ b/clang/lib/CodeGen/CGDebugInfo.h
@@ -192,6 +192,7 @@
llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+ llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 85c2d31..1701c90 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@
Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
const Twine &Name, Address *Alloca) {
- return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
- /*ArraySize=*/nullptr, Alloca);
+ Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+ /*ArraySize=*/nullptr, Alloca);
+
+ // For matrices, ConvertTypeForMem yields an IR array, so the alloca has
+ // array type. The returned address is bitcast to a pointer to the IR vector
+ // with the same element type and count (the matrix value type produced by
+ // ConvertType). The alloca's alignment is kept.
+ if (Ty->isConstantMatrixType()) {
+ auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ Result = Address(
+ Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()),
+ Result.getAlignment());
+ }
+ return Result;
}
Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1732,6 +1743,42 @@
return Value;
}
+// Convert the pointer of \p Addr to a pointer to a vector (the value type of
+// MatrixType), if it points to an array (the memory type of MatrixType).
+// With \p IsVector == false the conversion runs the other way: a pointer to a
+// vector is turned into a pointer to the equivalent array. Any other pointer
+// (or a pointer already of the requested kind) is returned unchanged.
+static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF,
+ bool IsVector = true) {
+ auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+ cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+ if (ArrayTy && IsVector) {
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ // NOTE(review): CreateElementBitCast already returns an Address; the
+ // outer Address(...) only copies it.
+ return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy));
+ }
+ auto *VectorTy = dyn_cast<llvm::VectorType>(
+ cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+ if (VectorTy && !IsVector) {
+ auto *ArrayTy = llvm::ArrayType::get(VectorTy->getElementType(),
+ VectorTy->getNumElements());
+
+ return Address(CGF.Builder.CreateElementBitCast(Addr, ArrayTy));
+ }
+
+ return Addr;
+}
+
+// Emit a store of a matrix LValue. This may require casting the original
+// pointer to memory address (ArrayType) to a pointer to the value type
+// (VectorType).
+// The target pointer kind follows the stored value: a vector value is stored
+// through a vector pointer, anything else through an array pointer.
+static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue,
+ bool isInit, CodeGenFunction &CGF) {
+ Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF,
+ value->getType()->isVectorTy());
+ CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(),
+ lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit,
+ lvalue.isNontemporal());
+}
+
void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
bool Volatile, QualType Ty,
LValueBaseInfo BaseInfo,
@@ -1779,11 +1826,26 @@
void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue,
bool isInit) {
+ // Matrix lvalues are routed through EmitStoreOfMatrixScalar, which may
+ // reinterpret the address between its array (memory) and vector (value)
+ // pointer forms before storing.
+ if (lvalue.getType()->isConstantMatrixType()) {
+ EmitStoreOfMatrixScalar(value, lvalue, isInit, *this);
+ return;
+ }
+
EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(),
lvalue.getType(), lvalue.getBaseInfo(),
lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal());
}
+// Emit a load of an LValue of matrix type. This may require casting the pointer
+// to memory address (ArrayType) to a pointer to the value type (VectorType).
+// The (possibly converted) address is written back into the local copy of
+// \p LV before the ordinary scalar load, so the result is a vector value.
+static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc,
+ CodeGenFunction &CGF) {
+ assert(LV.getType()->isConstantMatrixType());
+ Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF);
+ LV.setAddress(Addr);
+ return RValue::get(CGF.EmitLoadOfScalar(LV, Loc));
+}
+
/// EmitLoadOfLValue - Given an expression that represents a value lvalue, this
/// method emits the address of the lvalue, then loads the result as an rvalue,
/// returning the rvalue.
@@ -1809,6 +1871,9 @@
if (LV.isSimple()) {
assert(!LV.getType()->isFunctionType());
+ if (LV.getType()->isConstantMatrixType())
+ return EmitLoadOfMatrixLValue(LV, Loc, *this);
+
// Everything needs a load.
return RValue::get(EmitLoadOfScalar(LV, Loc));
}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 4fcf31a..cbe4823 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -247,6 +247,7 @@
case Type::MemberPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Enum:
@@ -2000,6 +2001,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 7ae9e91..19b8ff3 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -82,6 +82,13 @@
/// a type. For example, the scalar representation for _Bool is i1, but the
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) {
+ if (T->isConstantMatrixType()) {
+ const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+ const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+ return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ }
+
llvm::Type *R = ConvertType(T);
// If this is a bool type, or an ExtIntType in a bitfield representation,
@@ -646,6 +653,12 @@
VT->getNumElements());
break;
}
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+ ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ break;
+ }
case Type::FunctionNoProto:
case Type::FunctionProto:
ResultType = ConvertFunctionTypeInternal(T);
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 4a591cf..ceef143 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3223,6 +3223,7 @@
// GCC treats vector and complex types as fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::Atomic:
// FIXME: GCC treats block pointers as fundamental types?!
@@ -3458,6 +3459,7 @@
case Type::Builtin:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::BlockPointer:
// Itanium C++ ABI 2.9.5p4:
diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index 42d5af7..d6f05be 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -4566,6 +4566,13 @@
if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
CmdArgs.push_back("-fdefault-calling-conv=stdcall");
+ if (Args.hasArg(options::OPT_fenable_matrix)) {
+ // enable-matrix is needed by both the LangOpts and by LLVM.
+ CmdArgs.push_back("-fenable-matrix");
+ CmdArgs.push_back("-mllvm");
+ CmdArgs.push_back("-enable-matrix");
+ }
+
CodeGenOptions::FramePointerKind FPKeepKind =
getFramePointerKind(Args, RawTriple);
const char *FPKeepKindStr = nullptr;
diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp
index c5166ad..ec41547 100644
--- a/clang/lib/Frontend/CompilerInvocation.cpp
+++ b/clang/lib/Frontend/CompilerInvocation.cpp
@@ -3337,6 +3337,8 @@
Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
+ Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix);
+
Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) {
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 948880e..6f25f71 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -4257,6 +4257,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp
index aa83f82..dbbd190 100644
--- a/clang/lib/Sema/SemaLookup.cpp
+++ b/clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@
// These are fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::ExtInt:
break;
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
old mode 100755
new mode 100644
index 66b8e4d..c2324f9
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -5867,6 +5867,11 @@
return Visit(T->getElementType());
}
+// Only the element type is checked; the row and column expressions are not
+// visited.
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+ const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
const DependentAddressSpaceType *T) {
return Visit(T->getPointeeType());
@@ -5885,6 +5890,11 @@
return Visit(T->getElementType());
}
+// A constant matrix can only involve an unnamed/local type through its
+// element type, so that is all that is checked.
+bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType(
+ const ConstantMatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
const FunctionProtoType* T) {
for (const auto &A : T->param_types()) {
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index e1d438f..19f8248 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2057,6 +2057,101 @@
// (clang extension)
//
+ // T __attribute__((matrix_type(<integral constant>,
+ // <integral constant>)))
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MatrixArg = dyn_cast<ConstantMatrixType>(Arg);
+ if (!MatrixArg)
+ return Sema::TDK_NonDeducedMismatch;
+
+ const ConstantMatrixType *MatrixParam = cast<ConstantMatrixType>(Param);
+ // Check that the dimensions are the same
+ if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+ MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+ return Sema::TDK_NonDeducedMismatch;
+ }
+ // Perform deduction on element types.
+ return DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, TDF);
+ }
+
+ case Type::DependentSizedMatrix: {
+ const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg);
+ if (!MatrixArg)
+ return Sema::TDK_NonDeducedMismatch;
+
+ // Check the element type of the matrixes.
+ const DependentSizedMatrixType *MatrixParam =
+ cast<DependentSizedMatrixType>(Param);
+ if (Sema::TemplateDeductionResult Result =
+ DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, TDF))
+ return Result;
+
+ // Try to deduce a matrix dimension.
+ auto DeduceMatrixArg =
+ [&S, &Info, &Deduced, &TemplateParams](
+ Expr *ParamExpr, const MatrixType *Arg,
+ unsigned (ConstantMatrixType::*GetArgDimension)() const,
+ Expr *(DependentSizedMatrixType::*GetArgDimensionExpr)() const) {
+ const auto *ArgConstMatrix = dyn_cast<ConstantMatrixType>(Arg);
+ const auto *ArgDepMatrix = dyn_cast<DependentSizedMatrixType>(Arg);
+ if (!ParamExpr->isValueDependent()) {
+ llvm::APSInt ParamConst(
+ S.Context.getTypeSize(S.Context.getSizeType()));
+ if (!ParamExpr->isIntegerConstantExpr(ParamConst, S.Context))
+ return Sema::TDK_NonDeducedMismatch;
+
+ if (ArgConstMatrix) {
+ if ((ArgConstMatrix->*GetArgDimension)() == ParamConst)
+ return Sema::TDK_Success;
+ return Sema::TDK_NonDeducedMismatch;
+ }
+
+ Expr *ArgExpr = (ArgDepMatrix->*GetArgDimensionExpr)();
+ llvm::APSInt ArgConst(
+ S.Context.getTypeSize(S.Context.getSizeType()));
+ if (!ArgExpr->isValueDependent() &&
+ ArgExpr->isIntegerConstantExpr(ArgConst, S.Context) &&
+ ArgConst == ParamConst)
+ return Sema::TDK_Success;
+ return Sema::TDK_NonDeducedMismatch;
+ }
+
+ NonTypeTemplateParmDecl *NTTP =
+ getDeducedParameterFromExpr(Info, ParamExpr);
+ if (!NTTP)
+ return Sema::TDK_Success;
+
+ if (ArgConstMatrix) {
+ llvm::APSInt ArgConst(
+ S.Context.getTypeSize(S.Context.getSizeType()));
+ ArgConst = (ArgConstMatrix->*GetArgDimension)();
+ return DeduceNonTypeTemplateArgument(
+ S, TemplateParams, NTTP, ArgConst, S.Context.getSizeType(),
+ /*ArrayBound=*/true, Info, Deduced);
+ }
+
+ return DeduceNonTypeTemplateArgument(
+ S, TemplateParams, NTTP, (ArgDepMatrix->*GetArgDimensionExpr)(),
+ Info, Deduced);
+ };
+
+ auto Result = DeduceMatrixArg(MatrixParam->getRowExpr(), MatrixArg,
+ &ConstantMatrixType::getNumRows,
+ &DependentSizedMatrixType::getRowExpr);
+ if (Result)
+ return Result;
+
+ return DeduceMatrixArg(MatrixParam->getColumnExpr(), MatrixArg,
+ &ConstantMatrixType::getNumColumns,
+ &DependentSizedMatrixType::getColumnExpr);
+ }
+
+ // (clang extension)
+ //
// T __attribute__(((address_space(N))))
case Type::DependentAddressSpace: {
const DependentAddressSpaceType *AddressSpaceParam =
@@ -5723,6 +5818,24 @@
break;
}
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MatType = cast<ConstantMatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+ Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
case Type::FunctionProto: {
const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 338e335..df8ad7a 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2492,14 +2492,15 @@
if (!VecSize.isIntN(61)) {
// Bit size will overflow uint64.
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << SizeExpr->getSourceRange();
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
uint64_t VectorSizeBits = VecSize.getZExtValue() * 8;
unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
if (VectorSizeBits == 0) {
- Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+ Diag(AttrLoc, diag::err_attribute_zero_size)
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
@@ -2511,7 +2512,7 @@
if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << SizeExpr->getSourceRange();
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
@@ -2549,7 +2550,7 @@
if (!vecSize.isIntN(32)) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << ArraySize->getSourceRange();
+ << ArraySize->getSourceRange() << "vector";
return QualType();
}
// Unlike gcc's vector_size attribute, the size is specified as the
@@ -2558,7 +2559,7 @@
if (vectorSize == 0) {
Diag(AttrLoc, diag::err_attribute_zero_size)
- << ArraySize->getSourceRange();
+ << ArraySize->getSourceRange() << "vector";
return QualType();
}
@@ -2568,6 +2569,84 @@
return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
}
+QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().MatrixTypes &&
+         "Should never build a matrix type when it is disabled");
+  // Dependent dimensions: defer every check until the type is instantiated.
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent())
+    return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
+                                               AttrLoc);
+
+  // Check element type, if it is not dependent.
+  if (!ElementTy->isDependentType() &&
+      !MatrixType::isValidElementType(ElementTy)) {
+    Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
+    return QualType();
+  }
+
+  // Both row and column values can only be 20 bit wide currently.
+  llvm::APSInt ValueRows(32), ValueColumns(32);
+
+  bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+  bool const ColumnsIsInteger =
+      NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+  auto const RowRange = NumRows->getSourceRange();
+  auto const ColRange = NumCols->getSourceRange();
+
+  // Both the row and column expressions are invalid.
+  if (!RowsIsInteger && !ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+        << ColRange;
+    return QualType();
+  }
+
+  // Only the row expression is invalid.
+  if (!RowsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+    return QualType();
+  }
+
+  // Only the column expression is invalid.
+  if (!ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+    return QualType();
+  }
+
+  // Validate the dimensions before narrowing them, so big values can't wrap.
+  uint64_t MatrixRows = ValueRows.getLimitedValue();
+  uint64_t MatrixColumns = ValueColumns.getLimitedValue();
+  if (MatrixRows == 0 && MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (MatrixRows == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << RowRange << "matrix row";
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << ColRange << "matrix column";
+    return QualType();
+  }
+  return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns);
+}
+
bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
if (T->isArrayType() || T->isFunctionType()) {
Diag(Loc, diag::err_func_returning_array_function)
@@ -6013,6 +6092,21 @@
"no address_space attribute found at the expected location!");
}
+static void fillMatrixTypeLoc(MatrixTypeLoc MTL,
+                              const ParsedAttributesView &Attrs) {
+  for (const ParsedAttr &PA : Attrs) {
+    if (PA.getKind() != ParsedAttr::AT_MatrixType)
+      continue;
+    // First matching attribute wins; the operand parens range is left empty.
+    MTL.setAttrNameLoc(PA.getLoc());
+    MTL.setAttrRowOperand(PA.getArgAsExpr(0));
+    MTL.setAttrColumnOperand(PA.getArgAsExpr(1));
+    MTL.setAttrOperandParensRange(SourceRange());
+    return;
+  }
+  llvm_unreachable("no matrix_type attribute found at the expected location!");
+}
+
/// Create and instantiate a TypeSourceInfo with type source information.
///
/// \param T QualType referring to the type as written in source code.
@@ -6061,6 +6155,9 @@
CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc();
}
+ if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>())
+ fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs());
+
// FIXME: Ordering here?
while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>())
CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc();
@@ -7706,6 +7803,68 @@
}
}
+/// HandleMatrixTypeAttr - matrix_type(rows, cols); on error CurType is kept.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().MatrixTypes) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *RowsExpr = nullptr;
+  Expr *ColsExpr = nullptr;
+
+  // TODO: Factor out; the row and column copies below must stay in lockstep.
+  // Get the number of rows
+  if (Attr.isArgIdent(0)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+    ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+                                          TemplateKeywordLoc, id, false, false);
+
+    if (Rows.isInvalid())
+      // TODO: maybe a good error message would be nice here
+      return;
+    RowsExpr = Rows.get();
+  } else {
+    assert(Attr.isArgExpr(0) &&
+           "Argument should be either an identifier or an expression");
+    RowsExpr = Attr.getArgAsExpr(0);
+  }
+
+  // Get the number of columns
+  if (Attr.isArgIdent(1)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+    ExprResult Columns = S.ActOnIdExpression(
+        S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+
+    if (Columns.isInvalid())
+      // TODO: a good error message would be nice here
+      return;
+    ColsExpr = Columns.get();
+  } else {
+    assert(Attr.isArgExpr(1) &&
+           "Argument should be either an identifier or an expression");
+    ColsExpr = Attr.getArgAsExpr(1);
+  }
+
+  // Create the matrix type.
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
static void HandleLifetimeBoundAttr(TypeProcessingState &State,
QualType &CurType,
ParsedAttr &Attr) {
@@ -7857,6 +8016,11 @@
break;
}
+ case ParsedAttr::AT_MatrixType:
+ HandleMatrixTypeAttr(type, attr, state.getSema());
+ attr.setUsedAsTypeAttr();
+ break;
+
MS_TYPE_ATTRS_CASELIST:
if (!handleMSPointerTypeQualifierAttr(state, attr, type))
attr.setUsedAsTypeAttr();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 0987fee..99f1784 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
Expr *SizeExpr,
SourceLocation AttributeLoc);
+  /// Build a new constant matrix type from an element type and dimensions.
+  QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                     unsigned NumColumns);
+
+  /// Build a matrix type from (possibly dependent) row/column expressions;
+  /// by default this defers to Sema::BuildMatrixType.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
/// Build a new DependentAddressSpaceType or return the pointee
/// type variable with the correct address space (retrieved from
/// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5180,6 +5190,86 @@
}
template <typename Derived>
+QualType
+TreeTransform<Derived>::TransformConstantMatrixType(TypeLocBuilder &TLB,
+ ConstantMatrixTypeLoc TL) {
+ const ConstantMatrixType *T = TL.getTypePtr();
+ QualType ElementType = getDerived().TransformType(T->getElementType());
+ if (ElementType.isNull())
+ return QualType();
+
+ QualType Result = TL.getType();
+ if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) {
+ Result = getDerived().RebuildConstantMatrixType(
+ ElementType, T->getNumRows(), T->getNumColumns());
+ if (Result.isNull())
+ return QualType();
+ }
+
+ ConstantMatrixTypeLoc NewTL = TLB.push<ConstantMatrixTypeLoc>(Result);
+ NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+ NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+ NewTL.setAttrRowOperand(TL.getAttrRowOperand());
+ NewTL.setAttrColumnOperand(TL.getAttrColumnOperand());
+
+ return Result;
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+    TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+  const DependentSizedMatrixType *T = TL.getTypePtr();
+  // Transform the element type first; give up if that fails.
+  QualType ElementType = getDerived().TransformType(T->getElementType());
+  if (ElementType.isNull()) {
+    return QualType();
+  }
+  // Matrix dimensions are constant expressions: transform them in a
+  // constant-evaluated context (the object lives until this function returns).
+  EnterExpressionEvaluationContext Unevaluated(
+      SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+  // Prefer the operands recorded on the TypeLoc, else the type's own exprs.
+  Expr *origRows = TL.getAttrRowOperand();
+  if (!origRows)
+    origRows = T->getRowExpr();
+  Expr *origColumns = TL.getAttrColumnOperand();
+  if (!origColumns)
+    origColumns = T->getColumnExpr();
+  // Transform each dimension and pass it through ActOnConstantExpression.
+  ExprResult rowResult = getDerived().TransformExpr(origRows);
+  rowResult = SemaRef.ActOnConstantExpression(rowResult);
+  if (rowResult.isInvalid())
+    return QualType();
+
+  ExprResult columnResult = getDerived().TransformExpr(origColumns);
+  columnResult = SemaRef.ActOnConstantExpression(columnResult);
+  if (columnResult.isInvalid())
+    return QualType();
+
+  Expr *rows = rowResult.get();
+  Expr *columns = columnResult.get();
+  // Rebuild only if AlwaysRebuild() is set or some component changed.
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+      rows != origRows || columns != origColumns) {
+    Result = getDerived().RebuildDependentSizedMatrixType(
+        ElementType, rows, columns, T->getAttributeLoc());
+
+    if (Result.isNull())
+      return QualType();
+  }
+
+  // The rebuilt type may be constant or dependent-sized; both share
+  // MatrixTypeLoc's location layout, so push the common base TypeLoc.
+  MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+  NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+  NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+  NewTL.setAttrRowOperand(rows);
+  NewTL.setAttrColumnOperand(columns);
+  return Result;
+}
+
+template <typename Derived>
QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
const DependentAddressSpaceType *T = TL.getTypePtr();
@@ -13750,6 +13840,21 @@
return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
}
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildConstantMatrixType(
+    QualType ElementType, unsigned NumRows, unsigned NumColumns) {
+  ASTContext &Ctx = SemaRef.Context;
+  return Ctx.getConstantMatrixType(ElementType, NumRows, NumColumns);
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+    QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+    SourceLocation AttributeLoc) {
+  Sema &S = SemaRef;
+  return S.BuildMatrixType(ElementType, RowExpr, ColumnExpr, AttributeLoc);
+}
+
template<typename Derived>
QualType TreeTransform<Derived>::RebuildFunctionProtoType(
QualType T,
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 187665b..3f41646 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6554,6 +6554,21 @@
TL.setNameLoc(readSourceLocation());
}
+// Both matrix visitors read: name loc, parens range, row expr, column expr.
+void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation());
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  TL.setAttrNameLoc(readSourceLocation());
+  TL.setAttrOperandParensRange(Reader.readSourceRange());
+  TL.setAttrRowOperand(Reader.readExpr());
+  TL.setAttrColumnOperand(Reader.readExpr());
+}
+
void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
TL.setLocalRangeBegin(readSourceLocation());
TL.setLParenLoc(readSourceLocation());
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index b2281ab..f4e54e4 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,25 @@
Record.AddSourceLocation(TL.getNameLoc());
}
+void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+  // Layout, read back in the same order by TypeLocReader: attribute name,
+  // operand parens range, row operand, column operand.
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  Record.AddSourceRange(TL.getAttrOperandParensRange());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+    DependentSizedMatrixTypeLoc TL) {
+  // Same field order as the constant case; TypeLocReader reads it back as
+  // attribute name, operand parens range, row operand, column operand.
+  Record.AddSourceLocation(TL.getAttrNameLoc());
+  Record.AddSourceRange(TL.getAttrOperandParensRange());
+  Record.AddStmt(TL.getAttrRowOperand());
+  Record.AddStmt(TL.getAttrColumnOperand());
+}
+
void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
Record.AddSourceLocation(TL.getLocalRangeBegin());
Record.AddSourceLocation(TL.getLParenLoc());