[ARM] Add __bf16 as a new BFloat16 C type
Summary:
This patch upstreams support for a new storage-only bfloat16 C type.
This type is used to implement primitive support for bfloat16 data, in
line with the BFloat16 extension of the Armv8.6-A architecture, as
detailed here:
https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a
The bfloat16 type and its properties are specified in the Arm Architecture
Reference Manual:
https://developer.arm.com/docs/ddi0487/latest/arm-architecture-reference-manual-armv8-for-armv8-a-architecture-profile
In detail, this patch:
- introduces an opaque, storage-only C type, __bf16, which maps to a new bfloat IR type.
This is part of a patch series that starts with command-line and BFloat16
assembly support. Subsequent patches will upstream intrinsics support for
BFloat16, followed by Matrix Multiplication and the remaining
Virtualization features of the Armv8.6-A architecture.
The following people contributed to this patch:
- Luke Cheeseman
- Momchil Velikov
- Alexandros Lamprineas
- Luke Geeson
- Simon Tatham
- Ties Stuij
Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, fpetrogalli
Reviewed By: SjoerdMeijer
Subscribers: labrinea, majnemer, asmith, dexonsmith, kristof.beyls, arphaman, danielkiss, cfe-commits
Tags: #clang
Differential Revision: https://reviews.llvm.org/D76077
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index eccc3e2..60482ea 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -100,7 +100,7 @@
using namespace clang;
enum FloatingRank {
- Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank
+ BFloat16Rank, Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank
};
/// \returns location that is relevant when searching for Doc comments related
@@ -1448,6 +1448,8 @@
// half type (OpenCL 6.1.1.1) / ARM NEON __fp16
InitBuiltinType(HalfTy, BuiltinType::Half);
+ InitBuiltinType(BFloat16Ty, BuiltinType::BFloat16);
+
// Builtin type used to help define __builtin_va_list.
VaListTagDecl = nullptr;
@@ -1651,6 +1653,8 @@
switch (T->castAs<BuiltinType>()->getKind()) {
default:
llvm_unreachable("Not a floating point type!");
+ case BuiltinType::BFloat16:
+ return Target->getBFloat16Format();
case BuiltinType::Float16:
case BuiltinType::Half:
return Target->getHalfFormat();
@@ -2045,6 +2049,10 @@
Width = Target->getLongFractWidth();
Align = Target->getLongFractAlign();
break;
+ case BuiltinType::BFloat16:
+ Width = Target->getBFloat16Width();
+ Align = Target->getBFloat16Align();
+ break;
case BuiltinType::Float16:
case BuiltinType::Half:
if (Target->hasFloat16Type() || !getLangOpts().OpenMP ||
@@ -5984,6 +5992,7 @@
case BuiltinType::Double: return DoubleRank;
case BuiltinType::LongDouble: return LongDoubleRank;
case BuiltinType::Float128: return Float128Rank;
+ case BuiltinType::BFloat16: return BFloat16Rank;
}
}
@@ -5996,6 +6005,7 @@
FloatingRank EltRank = getFloatingRank(Size);
if (Domain->isComplexType()) {
switch (EltRank) {
+ case BFloat16Rank: llvm_unreachable("Complex bfloat16 is not supported");
case Float16Rank:
case HalfRank: llvm_unreachable("Complex half is not supported");
case FloatRank: return FloatComplexTy;
@@ -6008,6 +6018,7 @@
assert(Domain->isRealFloatingType() && "Unknown domain!");
switch (EltRank) {
case Float16Rank: return HalfTy;
+ case BFloat16Rank: return BFloat16Ty;
case HalfRank: return HalfTy;
case FloatRank: return FloatTy;
case DoubleRank: return DoubleTy;
@@ -6985,6 +6996,7 @@
case BuiltinType::LongDouble: return 'D';
case BuiltinType::NullPtr: return '*'; // like char*
+ case BuiltinType::BFloat16:
case BuiltinType::Float16:
case BuiltinType::Float128:
case BuiltinType::Half:
@@ -9892,6 +9904,11 @@
// Read the base type.
switch (*Str++) {
default: llvm_unreachable("Unknown builtin type letter!");
+ case 'y':
+ assert(HowLong == 0 && !Signed && !Unsigned &&
+ "Bad modifiers used with 'y'!");
+ Type = Context.BFloat16Ty;
+ break;
case 'v':
assert(HowLong == 0 && !Signed && !Unsigned &&
"Bad modifiers used with 'v'!");
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 487ff8f..aae33c5 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2764,6 +2764,11 @@
Out << TI->getFloat128Mangling();
break;
}
+ case BuiltinType::BFloat16: {
+ const TargetInfo *TI = &getASTContext().getTargetInfo();
+ Out << TI->getBFloat16Mangling();
+ break;
+ }
case BuiltinType::NullPtr:
Out << "Dn";
break;
@@ -3179,7 +3184,8 @@
case BuiltinType::ULongLong: EltName = "uint64_t"; break;
case BuiltinType::Double: EltName = "float64_t"; break;
case BuiltinType::Float: EltName = "float32_t"; break;
- case BuiltinType::Half: EltName = "float16_t";break;
+ case BuiltinType::Half: EltName = "float16_t"; break;
+ case BuiltinType::BFloat16: EltName = "bfloat16_t"; break;
default:
llvm_unreachable("unexpected Neon vector element type");
}
@@ -3231,6 +3237,8 @@
return "Float32";
case BuiltinType::Double:
return "Float64";
+ case BuiltinType::BFloat16:
+ return "BFloat16";
default:
llvm_unreachable("Unexpected vector element base type");
}
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index e3796ac..ca628b3 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -2114,6 +2114,7 @@
case BuiltinType::SatUShortFract:
case BuiltinType::SatUFract:
case BuiltinType::SatULongFract:
+ case BuiltinType::BFloat16:
case BuiltinType::Float128: {
DiagnosticsEngine &Diags = Context.getDiags();
unsigned DiagID = Diags.getCustomDiagID(
diff --git a/clang/lib/AST/NSAPI.cpp b/clang/lib/AST/NSAPI.cpp
index ac06d33..ace7f1c 100644
--- a/clang/lib/AST/NSAPI.cpp
+++ b/clang/lib/AST/NSAPI.cpp
@@ -486,6 +486,7 @@
case BuiltinType::OMPArraySection:
case BuiltinType::OMPArrayShaping:
case BuiltinType::OMPIterator:
+ case BuiltinType::BFloat16:
break;
}
diff --git a/clang/lib/AST/PrintfFormatString.cpp b/clang/lib/AST/PrintfFormatString.cpp
index d13ff0a..f3ac181 100644
--- a/clang/lib/AST/PrintfFormatString.cpp
+++ b/clang/lib/AST/PrintfFormatString.cpp
@@ -752,6 +752,7 @@
case BuiltinType::UInt128:
case BuiltinType::Int128:
case BuiltinType::Half:
+ case BuiltinType::BFloat16:
case BuiltinType::Float16:
case BuiltinType::Float128:
case BuiltinType::ShortAccum:
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index deb2186..7dd85d1 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2137,7 +2137,8 @@
bool Type::isArithmeticType() const {
if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
return BT->getKind() >= BuiltinType::Bool &&
- BT->getKind() <= BuiltinType::Float128;
+ BT->getKind() <= BuiltinType::Float128 &&
+ BT->getKind() != BuiltinType::BFloat16;
if (const auto *ET = dyn_cast<EnumType>(CanonicalType))
// GCC allows forward declaration of enum types (forbid by C99 6.7.2.3p2).
// If a body isn't seen by the time we get here, return false.
@@ -2922,6 +2923,8 @@
return "unsigned __int128";
case Half:
return Policy.Half ? "half" : "__fp16";
+ case BFloat16:
+ return "__bf16";
case Float:
return "float";
case Double:
diff --git a/clang/lib/AST/TypeLoc.cpp b/clang/lib/AST/TypeLoc.cpp
index 366f4d8..57c11ca 100644
--- a/clang/lib/AST/TypeLoc.cpp
+++ b/clang/lib/AST/TypeLoc.cpp
@@ -375,6 +375,7 @@
case BuiltinType::SatUShortFract:
case BuiltinType::SatUFract:
case BuiltinType::SatULongFract:
+ case BuiltinType::BFloat16:
llvm_unreachable("Builtin type needs extra local data!");
// Fall through, if the impossible happens.