[AMDGPU] Add support for Whole Wavefront Mode
Summary:
Whole Wavefront Mode (WWM) is similar to WQM, except that all of the
lanes are always enabled, regardless of control flow. This is required
for implementing wavefront reductions in non-uniform control flow, where
we need to use the inactive lanes to propagate intermediate results, so
they need to be enabled. We need to propagate WWM to uses (unless
they're explicitly marked as exact) so that they also propagate
intermediate results correctly. We do the analysis and exec mask munging
during the WQM pass, since there are interactions with WQM for things
that require both WQM and WWM. For simplicity, WWM is entirely
block-local -- a block is never in WWM on entry or exit, and WWM
is not propagated to the block level. This means that computations
involving WWM cannot involve control flow, but we only ever plan to use
WWM for a few limited purposes (none of which involve control flow)
anyway.
Shaders can ask for WWM using the @llvm.amdgcn.wwm intrinsic. There
isn't yet a way to turn WWM off -- that will be added in a future
change.
Finally, it turns out that turning on inactive lanes causes a number of
problems with register allocation. While the best long-term solution
seems to be teaching LLVM's register allocator about predication, for now
we need to add some hacks to prevent ourselves from getting into trouble
due to constraints that aren't currently expressed in LLVM. For the gory
details, see the comments at the top of SIFixWWMLiveness.cpp.
Reviewers: arsenm, nhaehnle, tpr
Subscribers: kzhuravl, wdng, mgorny, yaxunl, dstuttard, t-tye, llvm-commits
Differential Revision: https://reviews.llvm.org/D35524
llvm-svn: 310087
diff --git a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
index 6d91a7b..1a0f0f9 100644
--- a/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
+++ b/llvm/lib/Target/AMDGPU/SIWholeQuadMode.cpp
@@ -9,7 +9,7 @@
//
/// \file
/// \brief This pass adds instructions to enable whole quad mode for pixel
-/// shaders.
+/// shaders, and whole wavefront mode for all programs.
///
/// Whole quad mode is required for derivative computations, but it interferes
/// with shader side effects (stores and atomics). This pass is run on the
@@ -29,6 +29,13 @@
/// ...
/// S_MOV_B64 EXEC, Tmp
///
+/// We also compute when a sequence of instructions requires Whole Wavefront
+/// Mode (WWM) and insert instructions to save and restore it:
+///
+/// S_OR_SAVEEXEC_B64 Tmp, -1
+/// ...
+/// S_MOV_B64 EXEC, Tmp
+///
/// In order to avoid excessive switching during sequences of Exact
/// instructions, the pass first analyzes which instructions must be run in WQM
/// (aka which instructions produce values that lead to derivative
@@ -85,7 +92,8 @@
enum {
StateWQM = 0x1,
- StateExact = 0x2,
+ StateWWM = 0x2,
+ StateExact = 0x4,
};
struct PrintState {
@@ -98,9 +106,14 @@
static raw_ostream &operator<<(raw_ostream &OS, const PrintState &PS) {
if (PS.State & StateWQM)
OS << "WQM";
- if (PS.State & StateExact) {
+ if (PS.State & StateWWM) {
if (PS.State & StateWQM)
OS << '|';
+ OS << "WWM";
+ }
+ if (PS.State & StateExact) {
+ if (PS.State & (StateWQM | StateWWM))
+ OS << '|';
OS << "Exact";
}
@@ -130,6 +143,7 @@
class SIWholeQuadMode : public MachineFunctionPass {
private:
+ CallingConv::ID CallingConv;
const SIInstrInfo *TII;
const SIRegisterInfo *TRI;
MachineRegisterInfo *MRI;
@@ -163,6 +177,10 @@
unsigned SaveWQM, unsigned LiveMaskReg);
void toWQM(MachineBasicBlock &MBB, MachineBasicBlock::iterator Before,
unsigned SavedWQM);
+ void toWWM(MachineBasicBlock &MBB, MachineBasicBlock::iterator Before,
+ unsigned SaveOrig);
+ void fromWWM(MachineBasicBlock &MBB, MachineBasicBlock::iterator Before,
+ unsigned SavedOrig);
void processBlock(MachineBasicBlock &MBB, unsigned LiveMaskReg, bool isEntry);
void lowerLiveMaskQueries(unsigned LiveMaskReg);
@@ -223,7 +241,7 @@
std::vector<WorkItem> &Worklist) {
InstrInfo &II = Instructions[&MI];
- assert(Flag == StateWQM);
+ assert(!(Flag & StateExact) && Flag != 0);
// Remove any disabled states from the flag. The user that required it gets
// an undefined value in the helper lanes. For example, this can happen if
@@ -243,7 +261,6 @@
/// Mark all instructions defining the uses in \p MI with \p Flag.
void SIWholeQuadMode::markInstructionUses(const MachineInstr &MI, char Flag,
std::vector<WorkItem> &Worklist) {
- assert(Flag == StateWQM);
for (const MachineOperand &Use : MI.uses()) {
if (!Use.isReg() || !Use.isUse())
continue;
@@ -302,7 +319,7 @@
unsigned Opcode = MI.getOpcode();
char Flags = 0;
- if (TII->isDS(Opcode)) {
+ if (TII->isDS(Opcode) && CallingConv == CallingConv::AMDGPU_PS) {
Flags = StateWQM;
} else if (TII->isWQM(Opcode)) {
// Sampling instructions don't need to produce results for all pixels
@@ -316,6 +333,14 @@
// correct, so we need it to be in WQM.
Flags = StateWQM;
LowerToCopyInstrs.push_back(&MI);
+ } else if (Opcode == AMDGPU::WWM) {
+ // The WWM intrinsic doesn't make the same guarantee, and plus it needs
+ // to be executed in WQM or Exact so that its copy doesn't clobber
+ // inactive lanes.
+ markInstructionUses(MI, StateWWM, Worklist);
+ GlobalFlags |= StateWWM;
+ LowerToCopyInstrs.push_back(&MI);
+ continue;
} else if (TII->isDisableWQM(MI)) {
BBI.Needs |= StateExact;
if (!(BBI.InNeeds & StateExact)) {
@@ -323,7 +348,7 @@
Worklist.push_back(&MBB);
}
GlobalFlags |= StateExact;
- III.Disabled = StateWQM;
+ III.Disabled = StateWQM | StateWWM;
continue;
} else {
if (Opcode == AMDGPU::SI_PS_LIVE) {
@@ -383,7 +408,7 @@
// Propagate backwards within block
if (MachineInstr *PrevMI = MI.getPrevNode()) {
- char InNeeds = II.Needs | II.OutNeeds;
+ char InNeeds = (II.Needs & ~StateWWM) | II.OutNeeds;
if (!PrevMI->isPHI()) {
InstrInfo &PrevII = Instructions[PrevMI];
if ((PrevII.OutNeeds | InNeeds) != PrevII.OutNeeds) {
@@ -589,6 +614,29 @@
LIS->InsertMachineInstrInMaps(*MI);
}
+void SIWholeQuadMode::toWWM(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator Before,
+ unsigned SaveOrig) {
+ MachineInstr *MI;
+ // Enter WWM before Before: S_OR_SAVEEXEC_B64 with -1 saves EXEC into SaveOrig and sets every EXEC bit; the new instr is registered with LIS.
+ assert(SaveOrig);
+ MI = BuildMI(MBB, Before, DebugLoc(), TII->get(AMDGPU::S_OR_SAVEEXEC_B64),
+ SaveOrig)
+ .addImm(-1);
+ LIS->InsertMachineInstrInMaps(*MI);
+}
+
+void SIWholeQuadMode::fromWWM(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator Before,
+ unsigned SavedOrig) {
+ MachineInstr *MI;
+ // Leave WWM before Before: EXIT_WWM writes EXEC back from SavedOrig (the register toWWM saved into); the new instr is registered with LIS.
+ assert(SavedOrig);
+ MI = BuildMI(MBB, Before, DebugLoc(), TII->get(AMDGPU::EXIT_WWM), AMDGPU::EXEC)
+ .addReg(SavedOrig);
+ LIS->InsertMachineInstrInMaps(*MI);
+}
+
void SIWholeQuadMode::processBlock(MachineBasicBlock &MBB, unsigned LiveMaskReg,
bool isEntry) {
auto BII = Blocks.find(&MBB);
@@ -597,45 +645,63 @@
const BlockInfo &BI = BII->second;
- if (!(BI.InNeeds & StateWQM))
- return;
-
// This is a non-entry block that is WQM throughout, so no need to do
// anything.
- if (!isEntry && !(BI.Needs & StateExact) && BI.OutNeeds != StateExact)
+ if (!isEntry && BI.Needs == StateWQM && BI.OutNeeds != StateExact)
return;
DEBUG(dbgs() << "\nProcessing block BB#" << MBB.getNumber() << ":\n");
unsigned SavedWQMReg = 0;
+ unsigned SavedNonWWMReg = 0;
bool WQMFromExec = isEntry;
- char State = isEntry ? StateExact : StateWQM;
+ char State = (isEntry || !(BI.InNeeds & StateWQM)) ? StateExact : StateWQM;
+ char NonWWMState = 0;
auto II = MBB.getFirstNonPHI(), IE = MBB.end();
if (isEntry)
++II; // Skip the instruction that saves LiveMask
- MachineBasicBlock::iterator First = IE;
+ // This stores the first instruction where it's safe to switch from WQM to
+ // Exact or vice versa.
+ MachineBasicBlock::iterator FirstWQM = IE;
+
+ // This stores the first instruction where it's safe to switch from WWM to
+ // Exact/WQM or to switch to WWM. It must always be the same as, or after,
+ // FirstWQM since if it's safe to switch to/from WWM, it must be safe to
+ // switch to/from WQM as well.
+ MachineBasicBlock::iterator FirstWWM = IE;
for (;;) {
MachineBasicBlock::iterator Next = II;
- char Needs = StateExact | StateWQM;
+ char Needs = StateExact | StateWQM; // WWM is disabled by default
char OutNeeds = 0;
- if (First == IE)
- First = II;
+ if (FirstWQM == IE)
+ FirstWQM = II;
+ if (FirstWWM == IE)
+ FirstWWM = II;
+
+ // First, figure out the allowed states (Needs) based on the propagated
+ // flags.
if (II != IE) {
MachineInstr &MI = *II;
if (requiresCorrectState(MI)) {
auto III = Instructions.find(&MI);
if (III != Instructions.end()) {
- if (III->second.Needs & StateWQM)
+ if (III->second.Needs & StateWWM)
+ Needs = StateWWM;
+ else if (III->second.Needs & StateWQM)
Needs = StateWQM;
else
Needs &= ~III->second.Disabled;
OutNeeds = III->second.OutNeeds;
}
+ } else {
+ // If the instruction doesn't actually need a correct EXEC, then we can
+ // safely leave WWM enabled.
+ Needs = StateExact | StateWQM | StateWWM;
}
if (MI.isTerminator() && OutNeeds == StateExact)
@@ -655,35 +721,63 @@
Needs = StateWQM | StateExact;
}
+ // Now, transition if necessary.
if (!(Needs & State)) {
+ MachineBasicBlock::iterator First;
+ if (State == StateWWM || Needs == StateWWM) {
+ // We must switch to or from WWM
+ First = FirstWWM;
+ } else {
+ // We only need to switch to/from WQM, so we can use FirstWQM
+ First = FirstWQM;
+ }
+
MachineBasicBlock::iterator Before =
prepareInsertion(MBB, First, II, Needs == StateWQM,
Needs == StateExact || WQMFromExec);
- if (Needs == StateExact) {
- if (!WQMFromExec && (OutNeeds & StateWQM))
- SavedWQMReg = MRI->createVirtualRegister(&AMDGPU::SReg_64RegClass);
-
- toExact(MBB, Before, SavedWQMReg, LiveMaskReg);
- State = StateExact;
- } else {
- assert(Needs == StateWQM);
- assert(WQMFromExec == (SavedWQMReg == 0));
-
- toWQM(MBB, Before, SavedWQMReg);
-
- if (SavedWQMReg) {
- LIS->createAndComputeVirtRegInterval(SavedWQMReg);
- SavedWQMReg = 0;
- }
- State = StateWQM;
+ if (State == StateWWM) {
+ assert(SavedNonWWMReg);
+ fromWWM(MBB, Before, SavedNonWWMReg);
+ State = NonWWMState;
}
- First = IE;
+ if (Needs == StateWWM) {
+ NonWWMState = State;
+ SavedNonWWMReg = MRI->createVirtualRegister(&AMDGPU::SReg_64RegClass);
+ toWWM(MBB, Before, SavedNonWWMReg);
+ State = StateWWM;
+ } else {
+ if (State == StateWQM && (Needs & StateExact) && !(Needs & StateWQM)) {
+ if (!WQMFromExec && (OutNeeds & StateWQM))
+ SavedWQMReg = MRI->createVirtualRegister(&AMDGPU::SReg_64RegClass);
+
+ toExact(MBB, Before, SavedWQMReg, LiveMaskReg);
+ State = StateExact;
+ } else if (State == StateExact && (Needs & StateWQM) &&
+ !(Needs & StateExact)) {
+ assert(WQMFromExec == (SavedWQMReg == 0));
+
+ toWQM(MBB, Before, SavedWQMReg);
+
+ if (SavedWQMReg) {
+ LIS->createAndComputeVirtRegInterval(SavedWQMReg);
+ SavedWQMReg = 0;
+ }
+ State = StateWQM;
+ } else {
+ // We can get here if we transitioned from WWM to a non-WWM state that
+ // already matches our needs, but we shouldn't need to do anything.
+ assert(Needs & State);
+ }
+ }
}
- if (Needs != (StateExact | StateWQM))
- First = IE;
+ if (Needs != (StateExact | StateWQM | StateWWM)) {
+ if (Needs != (StateExact | StateWQM))
+ FirstWQM = IE;
+ FirstWWM = IE;
+ }
if (II == IE)
break;
@@ -710,13 +804,11 @@
}
bool SIWholeQuadMode::runOnMachineFunction(MachineFunction &MF) {
- if (MF.getFunction()->getCallingConv() != CallingConv::AMDGPU_PS)
- return false;
-
Instructions.clear();
Blocks.clear();
LiveMaskQueries.clear();
LowerToCopyInstrs.clear();
+ CallingConv = MF.getFunction()->getCallingConv();
const SISubtarget &ST = MF.getSubtarget<SISubtarget>();
@@ -726,14 +818,13 @@
LIS = &getAnalysis<LiveIntervals>();
char GlobalFlags = analyzeFunction(MF);
+ unsigned LiveMaskReg = 0;
if (!(GlobalFlags & StateWQM)) {
lowerLiveMaskQueries(AMDGPU::EXEC);
- return !LiveMaskQueries.empty();
- }
-
- // Store a copy of the original live mask when required
- unsigned LiveMaskReg = 0;
- {
+ if (!(GlobalFlags & StateWWM))
+ return !LiveMaskQueries.empty();
+ } else {
+ // Store a copy of the original live mask when required
MachineBasicBlock &Entry = MF.front();
MachineBasicBlock::iterator EntryMI = Entry.getFirstNonPHI();
@@ -745,13 +836,14 @@
LIS->InsertMachineInstrInMaps(*MI);
}
+ lowerLiveMaskQueries(LiveMaskReg);
+
if (GlobalFlags == StateWQM) {
// For a shader that needs only WQM, we can just set it once.
BuildMI(Entry, EntryMI, DebugLoc(), TII->get(AMDGPU::S_WQM_B64),
AMDGPU::EXEC)
.addReg(AMDGPU::EXEC);
- lowerLiveMaskQueries(LiveMaskReg);
lowerCopyInstrs();
// EntryMI may become invalid here
return true;
@@ -760,7 +852,6 @@
DEBUG(printInfo());
- lowerLiveMaskQueries(LiveMaskReg);
lowerCopyInstrs();
// Handle the general case