// X86CondBrFolding.cpp revision 360784
//===---- X86CondBrFolding.cpp - optimize conditional branches ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass that optimizes condition branches on x86 by taking
// advantage of the three-way conditional code generated by compare
// instructions.
// Currently, it tries to hoist an EQ or NE conditional branch to a dominant
// conditional branch condition where the same EQ/NE conditional code is
// computed. An example:
//   bb_0:
//     cmp %0, 19
//     jg bb_1
//     jmp bb_2
//   bb_1:
//     cmp %0, 40
//     jg bb_3
//     jmp bb_4
//   bb_4:
//     cmp %0, 20
//     je bb_5
//     jmp bb_6
// Here we could combine the two compares in bb_0 and bb_4 and have the
// following code:
//   bb_0:
//     cmp %0, 20
//     jg bb_1
//     jl bb_2
//     jmp bb_5
//   bb_1:
//     cmp %0, 40
//     jg bb_3
//     jmp bb_6
// For the case of %0 == 20 (bb_5), we eliminate two jumps, and the control
// height for bb_6 is also reduced. bb_4 is gone after the optimization.
//
// There are plenty of these code patterns, especially from the switch-case
// lowering where we generate a compare of "pivot-1" for the inner nodes in the
// binary search tree.
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/BranchProbability.h"

using namespace llvm;

#define DEBUG_TYPE "x86-condbr-folding"

STATISTIC(NumFixedCondBrs, "Number of x86 condbr folded");

namespace {
// The MachineFunctionPass shell: it only gates on subtarget profitability and
// delegates all the real work to the X86CondBrFolding helper class below.
class X86CondBrFoldingPass : public MachineFunctionPass {
public:
  X86CondBrFoldingPass() : MachineFunctionPass(ID) { }
  StringRef getPassName() const override { return "X86 CondBr Folding"; }

  bool runOnMachineFunction(MachineFunction &MF) override;

  // Requires branch-probability info so the transformation can rescale edge
  // probabilities after rewiring the CFG.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    MachineFunctionPass::getAnalysisUsage(AU);
    AU.addRequired<MachineBranchProbabilityInfo>();
  }

public:
  static char ID;
};
} // namespace

char X86CondBrFoldingPass::ID = 0;
INITIALIZE_PASS(X86CondBrFoldingPass, "X86CondBrFolding", "X86CondBrFolding", false, false)

FunctionPass *llvm::createX86CondBrFolding() {
  return new X86CondBrFoldingPass();
}

namespace {
// Auxiliary information recorded for each candidate MBB by analyzeMBB.
struct TargetMBBInfo {
  // Destination of the conditional branch and of the trailing unconditional
  // jump. Note: analyzeMBB may swap these (COND_NE canonicalization), so they
  // reflect the *canonicalized* CondCode, not necessarily the instructions.
  MachineBasicBlock *TBB;
  MachineBasicBlock *FBB;
  MachineInstr *BrInstr;   // The conditional branch instruction.
  MachineInstr *CmpInstr;  // The CMP/SUB-immediate that sets the flags.
  X86::CondCode BranchCode; // Canonicalized to COND_E / COND_L / COND_G.
  unsigned SrcReg;         // Register compared against the immediate.
  int CmpValue;            // Immediate (adjusted by +/-1 for LE/GE).
  // True when BranchCode/TBB/FBB/CmpValue were canonicalized away from what
  // the instructions actually encode; fixupModifiedCond resynchronizes them.
  bool Modified;
  // True when the block contains nothing but the compare and its branches,
  // which makes it safe to fold through (see findPath).
  bool CmpBrOnly;
};

// A class that optimizes conditional branches by hoisting and merging CondCode.
class X86CondBrFolding {
public:
  X86CondBrFolding(const X86InstrInfo *TII,
                   const MachineBranchProbabilityInfo *MBPI,
                   MachineFunction &MF)
      : TII(TII), MBPI(MBPI), MF(MF) {}
  // Driver: analyzes every MBB, finds foldable compare chains, and applies
  // the transformation. Returns true if the function was changed.
  bool optimize();

private:
  const X86InstrInfo *TII;
  const MachineBranchProbabilityInfo *MBPI;
  MachineFunction &MF;
  // Per-MBB analysis results, indexed by MBB number; null for non-candidates.
  std::vector<std::unique_ptr<TargetMBBInfo>> MBBInfos;
  // Blocks made dead by the transformation; erased at the end of optimize().
  SmallVector<MachineBasicBlock *, 4> RemoveList;

  void optimizeCondBr(MachineBasicBlock &MBB,
                      SmallVectorImpl<MachineBasicBlock *> &BranchPath);
  void replaceBrDest(MachineBasicBlock *MBB, MachineBasicBlock *OrigDest,
                     MachineBasicBlock *NewDest);
  void fixupModifiedCond(MachineBasicBlock *MBB);
  std::unique_ptr<TargetMBBInfo> analyzeMBB(MachineBasicBlock &MBB);
  static bool analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
                             int &CmpValue);
  bool findPath(MachineBasicBlock *MBB,
                SmallVectorImpl<MachineBasicBlock *> &BranchPath);
  TargetMBBInfo *getMBBInfo(MachineBasicBlock *MBB) const {
    return MBBInfos[MBB->getNumber()].get();
  }
};
} // namespace

// Find a valid path on which we can reuse the CondCode.
// The resulting path (if we return true) is stored in BranchPath.
// Return value:
//  false: no valid path is found.
//  true: a valid path is found and the targetBB can be reached.
137bool X86CondBrFolding::findPath( 138 MachineBasicBlock *MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) { 139 TargetMBBInfo *MBBInfo = getMBBInfo(MBB); 140 assert(MBBInfo && "Expecting a candidate MBB"); 141 int CmpValue = MBBInfo->CmpValue; 142 143 MachineBasicBlock *PredMBB = *MBB->pred_begin(); 144 MachineBasicBlock *SaveMBB = MBB; 145 while (PredMBB) { 146 TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB); 147 if (!PredMBBInfo || PredMBBInfo->SrcReg != MBBInfo->SrcReg) 148 return false; 149 150 assert(SaveMBB == PredMBBInfo->TBB || SaveMBB == PredMBBInfo->FBB); 151 bool IsFalseBranch = (SaveMBB == PredMBBInfo->FBB); 152 153 X86::CondCode CC = PredMBBInfo->BranchCode; 154 assert(CC == X86::COND_L || CC == X86::COND_G || CC == X86::COND_E); 155 int PredCmpValue = PredMBBInfo->CmpValue; 156 bool ValueCmpTrue = ((CmpValue < PredCmpValue && CC == X86::COND_L) || 157 (CmpValue > PredCmpValue && CC == X86::COND_G) || 158 (CmpValue == PredCmpValue && CC == X86::COND_E)); 159 // Check if both the result of value compare and the branch target match. 160 if (!(ValueCmpTrue ^ IsFalseBranch)) { 161 LLVM_DEBUG(dbgs() << "Dead BB detected!\n"); 162 return false; 163 } 164 165 BranchPath.push_back(PredMBB); 166 // These are the conditions on which we could combine the compares. 167 if ((CmpValue == PredCmpValue) || 168 (CmpValue == PredCmpValue - 1 && CC == X86::COND_L) || 169 (CmpValue == PredCmpValue + 1 && CC == X86::COND_G)) 170 return true; 171 172 // If PredMBB has more than on preds, or not a pure cmp and br, we bailout. 173 if (PredMBB->pred_size() != 1 || !PredMBBInfo->CmpBrOnly) 174 return false; 175 176 SaveMBB = PredMBB; 177 PredMBB = *PredMBB->pred_begin(); 178 } 179 return false; 180} 181 182// Fix up any PHI node in the successor of MBB. 
// Rewrite every PHI in MBB so that incoming-block operands referring to
// OldMBB now refer to NewMBB.
static void fixPHIsInSucc(MachineBasicBlock *MBB, MachineBasicBlock *OldMBB,
                          MachineBasicBlock *NewMBB) {
  if (NewMBB == OldMBB)
    return;
  // PHIs are clustered at the top of the block; operands come in
  // (value, block) pairs starting at index 1, so the block operands sit at
  // the even indices 2, 4, ...
  for (auto MI = MBB->instr_begin(), ME = MBB->instr_end();
       MI != ME && MI->isPHI(); ++MI)
    for (unsigned i = 2, e = MI->getNumOperands() + 1; i != e; i += 2) {
      MachineOperand &MO = MI->getOperand(i);
      if (MO.getMBB() == OldMBB)
        MO.setMBB(NewMBB);
    }
}

// Utility function to set branch probability for edge MBB->SuccMBB.
// Returns false if SuccMBB is not a successor of MBB.
static inline bool setBranchProb(MachineBasicBlock *MBB,
                                 MachineBasicBlock *SuccMBB,
                                 BranchProbability Prob) {
  auto MBBI = std::find(MBB->succ_begin(), MBB->succ_end(), SuccMBB);
  if (MBBI == MBB->succ_end())
    return false;
  MBB->setSuccProbability(MBBI, Prob);
  return true;
}

// Utility function to find the unconditional br instruction in MBB.
// NOTE(review): returns MBB->end() if there is no JMP_1; callers assume the
// jump exists.
static inline MachineBasicBlock::iterator
findUncondBrI(MachineBasicBlock *MBB) {
  return std::find_if(MBB->begin(), MBB->end(), [](MachineInstr &MI) -> bool {
    return MI.getOpcode() == X86::JMP_1;
  });
}

// Replace MBB's original successor, OrigDest, with NewDest.
// Also update the MBBInfo for MBB: a fresh branch instruction is built
// targeting NewDest, the old one is erased, successor PHIs and the CFG
// successor list (with its edge probability) are updated.
void X86CondBrFolding::replaceBrDest(MachineBasicBlock *MBB,
                                     MachineBasicBlock *OrigDest,
                                     MachineBasicBlock *NewDest) {
  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
  MachineInstr *BrMI;
  if (MBBInfo->TBB == OrigDest) {
    // OrigDest was the conditional-branch target: rebuild the JCC.
    BrMI = MBBInfo->BrInstr;
    MachineInstrBuilder MIB =
        BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI), TII->get(X86::JCC_1))
            .addMBB(NewDest).addImm(MBBInfo->BranchCode);
    MBBInfo->TBB = NewDest;
    MBBInfo->BrInstr = MIB.getInstr();
  } else { // Should be the unconditional jump stmt.
    MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
    BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
        .addMBB(NewDest);
    MBBInfo->FBB = NewDest;
    BrMI = &*UncondBrI;
  }
  // Redirect PHIs before the old branch is erased, then move the edge
  // probability from the old edge onto the new one.
  fixPHIsInSucc(NewDest, OrigDest, MBB);
  BrMI->eraseFromParent();
  MBB->addSuccessor(NewDest);
  setBranchProb(MBB, NewDest, MBPI->getEdgeProbability(MBB, OrigDest));
  MBB->removeSuccessor(OrigDest);
}

// Change the CondCode and BrInstr according to MBBInfo: analyzeMBB may have
// canonicalized BranchCode/TBB/FBB (NE->E swap, LE->L, GE->G) without
// touching the instructions; this rebuilds the branch pair so the
// instructions match the recorded info, then clears the Modified flag.
void X86CondBrFolding::fixupModifiedCond(MachineBasicBlock *MBB) {
  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
  if (!MBBInfo->Modified)
    return;

  MachineInstr *BrMI = MBBInfo->BrInstr;
  X86::CondCode CC = MBBInfo->BranchCode;
  MachineInstrBuilder MIB = BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI),
                                    TII->get(X86::JCC_1))
                                .addMBB(MBBInfo->TBB).addImm(CC);
  BrMI->eraseFromParent();
  MBBInfo->BrInstr = MIB.getInstr();

  MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
  BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
      .addMBB(MBBInfo->FBB);
  MBB->erase(UncondBrI);
  MBBInfo->Modified = false;
}

//
// Apply the transformation:
//   RootMBB -1-> ... PredMBB -3-> MBB -5-> TargetMBB
//     \-2->            \-4->        \-6-> FalseMBB
// ==>
//   RootMBB -1-> ... PredMBB -7-> FalseMBB
//   TargetMBB <-8-/     \-2->       \-4->
//
// Note that PredMBB and RootMBB could be the same.
// And in the case of a dead TargetMBB, we will not have TargetMBB and edge 8.
//
// There is some special handling where the RootMBB is COND_E, in which case
// we directly short-cycle the brinstr.
//
// Rewire the CFG along BranchPath so that RootMBB's compare answers MBB's
// equality test directly (see the diagram above), then fix up edge
// probabilities and schedule the now-dead MBB for removal.
void X86CondBrFolding::optimizeCondBr(
    MachineBasicBlock &MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) {

  X86::CondCode CC;
  TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
  assert(MBBInfo && "Expecting a candidate MBB");
  MachineBasicBlock *TargetMBB = MBBInfo->TBB;
  BranchProbability TargetProb = MBPI->getEdgeProbability(&MBB, MBBInfo->TBB);

  // Forward the jump from MBB's predecessor to MBB's false target.
  MachineBasicBlock *PredMBB = BranchPath.front();
  TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB);
  assert(PredMBBInfo && "Expecting a candidate MBB");
  // Materialize any canonicalized CondCode before rewriting branches.
  if (PredMBBInfo->Modified)
    fixupModifiedCond(PredMBB);
  CC = PredMBBInfo->BranchCode;
  // Don't do this if depth of BranchPath is 1 and PredMBB is of COND_E.
  // We will short-cycle directly for this case.
  if (!(CC == X86::COND_E && BranchPath.size() == 1))
    replaceBrDest(PredMBB, &MBB, MBBInfo->FBB);

  MachineBasicBlock *RootMBB = BranchPath.back();
  TargetMBBInfo *RootMBBInfo = getMBBInfo(RootMBB);
  assert(RootMBBInfo && "Expecting a candidate MBB");
  if (RootMBBInfo->Modified)
    fixupModifiedCond(RootMBB);
  CC = RootMBBInfo->BranchCode;

  if (CC != X86::COND_E) {
    MachineBasicBlock::iterator UncondBrI = findUncondBrI(RootMBB);
    // RootMBB: Cond jump to the original not-taken MBB.
    // The root's L/G branch is inverted so its old fall-through becomes the
    // conditional target, freeing the fall-through slot for TargetMBB.
    X86::CondCode NewCC;
    switch (CC) {
    case X86::COND_L:
      NewCC = X86::COND_G;
      break;
    case X86::COND_G:
      NewCC = X86::COND_L;
      break;
    default:
      llvm_unreachable("unexpected condtional code.");
    }
    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
            TII->get(X86::JCC_1))
        .addMBB(RootMBBInfo->FBB).addImm(NewCC);

    // RootMBB: Jump to TargetMBB
    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
            TII->get(X86::JMP_1))
        .addMBB(TargetMBB);
    RootMBB->addSuccessor(TargetMBB);
    fixPHIsInSucc(TargetMBB, &MBB, RootMBB);
    RootMBB->erase(UncondBrI);
  } else {
    // Root already tests equality: just short-circuit its taken edge.
    replaceBrDest(RootMBB, RootMBBInfo->TBB, TargetMBB);
  }

  // Fix RootMBB's CmpValue to MBB's CmpValue to TargetMBB. Don't set Imm
  // directly. Move MBB's stmt to here as the opcode might be different.
  if (RootMBBInfo->CmpValue != MBBInfo->CmpValue) {
    MachineInstr *NewCmp = MBBInfo->CmpInstr;
    NewCmp->removeFromParent();
    RootMBB->insert(RootMBBInfo->CmpInstr, NewCmp);
    RootMBBInfo->CmpInstr->eraseFromParent();
  }

  // Fix branch Probabilities: walk the path from PredMBB up to RootMBB,
  // folding TargetMBB's probability into each hop. The lambda's return value
  // is unused by the caller.
  auto fixBranchProb = [&](MachineBasicBlock *NextMBB) {
    BranchProbability Prob;
    for (auto &I : BranchPath) {
      MachineBasicBlock *ThisMBB = I;
      if (!ThisMBB->hasSuccessorProbabilities() ||
          !ThisMBB->isSuccessor(NextMBB))
        break;
      Prob = MBPI->getEdgeProbability(ThisMBB, NextMBB);
      if (Prob.isUnknown())
        break;
      TargetProb = Prob * TargetProb;
      Prob = Prob - TargetProb;
      setBranchProb(ThisMBB, NextMBB, Prob);
      if (ThisMBB == RootMBB) {
        setBranchProb(ThisMBB, TargetMBB, TargetProb);
      }
      ThisMBB->normalizeSuccProbs();
      if (ThisMBB == RootMBB)
        break;
      NextMBB = ThisMBB;
    }
    return true;
  };
  if (CC != X86::COND_E && !TargetProb.isUnknown())
    fixBranchProb(MBBInfo->FBB);

  // In the non-COND_E case MBB is bypassed entirely and becomes dead.
  if (CC != X86::COND_E)
    RemoveList.push_back(&MBB);

  // Invalidate MBBInfo just in case.
  MBBInfos[MBB.getNumber()] = nullptr;
  MBBInfos[RootMBB->getNumber()] = nullptr;

  LLVM_DEBUG(dbgs() << "After optimization:\nRootMBB is: " << *RootMBB << "\n");
  if (BranchPath.size() > 1)
    LLVM_DEBUG(dbgs() << "PredMBB is: " << *(BranchPath[0]) << "\n");
}

// Driver function for optimization: find the valid candidate and apply
// the transformation.
bool X86CondBrFolding::optimize() {
  bool Changed = false;
  LLVM_DEBUG(dbgs() << "***** X86CondBr Folding on Function: " << MF.getName()
                    << " *****\n");
  // Setup data structures: one analysis slot per MBB number.
  MBBInfos.resize(MF.getNumBlockIDs());
  for (auto &MBB : MF)
    MBBInfos[MBB.getNumber()] = analyzeMBB(MBB);

  for (auto &MBB : MF) {
    TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
    // Only pure compare-and-branch blocks with a single predecessor are
    // candidates for hoisting.
    if (!MBBInfo || !MBBInfo->CmpBrOnly)
      continue;
    if (MBB.pred_size() != 1)
      continue;
    LLVM_DEBUG(dbgs() << "Work on MBB." << MBB.getNumber()
                      << " CmpValue: " << MBBInfo->CmpValue << "\n");
    SmallVector<MachineBasicBlock *, 4> BranchPath;
    if (!findPath(&MBB, BranchPath))
      continue;

#ifndef NDEBUG
    LLVM_DEBUG(dbgs() << "Found one path (len=" << BranchPath.size() << "):\n");
    int Index = 1;
    LLVM_DEBUG(dbgs() << "Target MBB is: " << MBB << "\n");
    for (auto I = BranchPath.rbegin(); I != BranchPath.rend(); ++I, ++Index) {
      MachineBasicBlock *PMBB = *I;
      TargetMBBInfo *PMBBInfo = getMBBInfo(PMBB);
      LLVM_DEBUG(dbgs() << "Path MBB (" << Index << " of " << BranchPath.size()
                        << ") is " << *PMBB);
      LLVM_DEBUG(dbgs() << "CC=" << PMBBInfo->BranchCode
                        << " Val=" << PMBBInfo->CmpValue
                        << " CmpBrOnly=" << PMBBInfo->CmpBrOnly << "\n\n");
    }
#endif
    optimizeCondBr(MBB, BranchPath);
    Changed = true;
  }
  NumFixedCondBrs += RemoveList.size();
  // Erase the blocks made dead by the transformation; detach successor edges
  // first so the CFG stays consistent.
  for (auto MBBI : RemoveList) {
    while (!MBBI->succ_empty())
      MBBI->removeSuccessor(MBBI->succ_end() - 1);

    MBBI->eraseFromParent();
  }

  return Changed;
}

// Analyze instructions that generate CondCode and extract information:
// the compared register into SrcReg and the immediate into CmpValue.
// Returns false for any opcode other than a reg-vs-immediate CMP/SUB.
bool X86CondBrFolding::analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
                                      int &CmpValue) {
  unsigned SrcRegIndex = 0;
  unsigned ValueIndex = 0;
  switch (MI.getOpcode()) {
  // TODO: handle test instructions.
  default:
    return false;
  case X86::CMP64ri32:
  case X86::CMP64ri8:
  case X86::CMP32ri:
  case X86::CMP32ri8:
  case X86::CMP16ri:
  case X86::CMP16ri8:
  case X86::CMP8ri:
    SrcRegIndex = 0;
    ValueIndex = 1;
    break;
  // SUB-with-immediate defines a register first, so the source/immediate
  // operands are shifted by one.
  case X86::SUB64ri32:
  case X86::SUB64ri8:
  case X86::SUB32ri:
  case X86::SUB32ri8:
  case X86::SUB16ri:
  case X86::SUB16ri8:
  case X86::SUB8ri:
    SrcRegIndex = 1;
    ValueIndex = 2;
    break;
  }
  SrcReg = MI.getOperand(SrcRegIndex).getReg();
  if (!MI.getOperand(ValueIndex).isImm())
    return false;
  CmpValue = MI.getOperand(ValueIndex).getImm();
  return true;
}

// Analyze a candidate MBB and extract all the information needed.
// The valid candidate will have two successors.
// It also should have a sequence of
//   Branch_instr,
//   CondBr,
//   UnCondBr.
// Return TargetMBBInfo if MBB is a valid candidate and nullptr otherwise.
std::unique_ptr<TargetMBBInfo>
X86CondBrFolding::analyzeMBB(MachineBasicBlock &MBB) {
  MachineBasicBlock *TBB;
  MachineBasicBlock *FBB;
  MachineInstr *BrInstr;
  MachineInstr *CmpInstr;
  X86::CondCode CC;
  unsigned SrcReg;
  int CmpValue;
  bool Modified;
  bool CmpBrOnly;

  if (MBB.succ_size() != 2)
    return nullptr;

  CmpBrOnly = true;
  FBB = TBB = nullptr;
  CmpInstr = nullptr;
  // Scan backwards: expect JMP_1, then the conditional branch, then the
  // compare. Any other instruction clears CmpBrOnly and stops the scan.
  MachineBasicBlock::iterator I = MBB.end();
  while (I != MBB.begin()) {
    --I;
    if (I->isDebugValue())
      continue;
    if (I->getOpcode() == X86::JMP_1) {
      if (FBB)
        return nullptr;
      FBB = I->getOperand(0).getMBB();
      continue;
    }
    if (I->isBranch()) {
      if (TBB)
        return nullptr;
      CC = X86::getCondFromBranch(*I);
      switch (CC) {
      default:
        return nullptr;
      case X86::COND_E:
      case X86::COND_L:
      case X86::COND_G:
      case X86::COND_NE:
      case X86::COND_LE:
      case X86::COND_GE:
        break;
      }
      TBB = I->getOperand(0).getMBB();
      BrInstr = &*I;
      continue;
    }
    if (analyzeCompare(*I, SrcReg, CmpValue)) {
      if (CmpInstr)
        return nullptr;
      CmpInstr = &*I;
      continue;
    }
    CmpBrOnly = false;
    break;
  }

  if (!TBB || !FBB || !CmpInstr)
    return nullptr;

  // Simplify CondCode. Note this is only to simplify the findPath logic
  // and will not change the instruction here.
  switch (CC) {
  case X86::COND_NE:
    CC = X86::COND_E;
    std::swap(TBB, FBB);
    Modified = true;
    break;
  case X86::COND_LE:
    // x <= INT_MAX can't be rewritten as x < INT_MAX+1 (would overflow).
    if (CmpValue == INT_MAX)
      return nullptr;
    CC = X86::COND_L;
    CmpValue += 1;
    Modified = true;
    break;
  case X86::COND_GE:
    if (CmpValue == INT_MIN)
      return nullptr;
    CC = X86::COND_G;
    CmpValue -= 1;
    Modified = true;
    break;
  default:
    Modified = false;
    break;
  }
  return std::make_unique<TargetMBBInfo>(TargetMBBInfo{
      TBB, FBB, BrInstr, CmpInstr, CC, SrcReg, CmpValue, Modified, CmpBrOnly});
}

bool X86CondBrFoldingPass::runOnMachineFunction(MachineFunction &MF) {
  // The subtarget decides whether three-way branches are profitable at all.
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  if (!ST.threewayBranchProfitable())
    return false;
  const X86InstrInfo *TII = ST.getInstrInfo();
  const MachineBranchProbabilityInfo *MBPI =
      &getAnalysis<MachineBranchProbabilityInfo>();

  X86CondBrFolding CondBr(TII, MBPI, MF);
  return CondBr.optimize();
}