// CostModel.cpp revision 263508
1//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// This file defines the cost model analysis. It provides a very basic cost 11// estimation for LLVM-IR. This analysis uses the services of the codegen 12// to approximate the cost of any IR instruction when lowered to machine 13// instructions. The cost results are unit-less and the cost number represents 14// the throughput of the machine assuming that all loads hit the cache, all 15// branches are predicted, etc. The cost numbers can be added in order to 16// compare two or more transformation alternatives. 17// 18//===----------------------------------------------------------------------===// 19 20#define CM_NAME "cost-model" 21#define DEBUG_TYPE CM_NAME 22#include "llvm/ADT/STLExtras.h" 23#include "llvm/Analysis/Passes.h" 24#include "llvm/Analysis/TargetTransformInfo.h" 25#include "llvm/IR/Function.h" 26#include "llvm/IR/Instructions.h" 27#include "llvm/IR/IntrinsicInst.h" 28#include "llvm/IR/Value.h" 29#include "llvm/Pass.h" 30#include "llvm/Support/CommandLine.h" 31#include "llvm/Support/Debug.h" 32#include "llvm/Support/raw_ostream.h" 33using namespace llvm; 34 35static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false), 36 cl::Hidden, 37 cl::desc("Recognize reduction patterns.")); 38 39namespace { 40 class CostModelAnalysis : public FunctionPass { 41 42 public: 43 static char ID; // Class identification, replacement for typeinfo 44 CostModelAnalysis() : FunctionPass(ID), F(0), TTI(0) { 45 initializeCostModelAnalysisPass( 46 *PassRegistry::getPassRegistry()); 47 } 48 49 /// Returns the expected cost of the instruction. 50 /// Returns -1 if the cost is unknown. 
51 /// Note, this method does not cache the cost calculation and it 52 /// can be expensive in some cases. 53 unsigned getInstructionCost(const Instruction *I) const; 54 55 private: 56 virtual void getAnalysisUsage(AnalysisUsage &AU) const; 57 virtual bool runOnFunction(Function &F); 58 virtual void print(raw_ostream &OS, const Module*) const; 59 60 /// The function that we analyze. 61 Function *F; 62 /// Target information. 63 const TargetTransformInfo *TTI; 64 }; 65} // End of anonymous namespace 66 67// Register this pass. 68char CostModelAnalysis::ID = 0; 69static const char cm_name[] = "Cost Model Analysis"; 70INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true) 71INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true) 72 73FunctionPass *llvm::createCostModelAnalysisPass() { 74 return new CostModelAnalysis(); 75} 76 77void 78CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { 79 AU.setPreservesAll(); 80} 81 82bool 83CostModelAnalysis::runOnFunction(Function &F) { 84 this->F = &F; 85 TTI = getAnalysisIfAvailable<TargetTransformInfo>(); 86 87 return false; 88} 89 90static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) { 91 for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i) 92 if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i)) 93 return false; 94 return true; 95} 96 97static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) { 98 TargetTransformInfo::OperandValueKind OpInfo = 99 TargetTransformInfo::OK_AnyValue; 100 101 // Check for a splat of a constant. 
102 ConstantDataVector *CDV = 0; 103 if ((CDV = dyn_cast<ConstantDataVector>(V))) 104 if (CDV->getSplatValue() != NULL) 105 OpInfo = TargetTransformInfo::OK_UniformConstantValue; 106 ConstantVector *CV = 0; 107 if ((CV = dyn_cast<ConstantVector>(V))) 108 if (CV->getSplatValue() != NULL) 109 OpInfo = TargetTransformInfo::OK_UniformConstantValue; 110 111 return OpInfo; 112} 113 114static bool matchMask(SmallVectorImpl<int> &M1, SmallVectorImpl<int> &M2) { 115 if (M1.size() != M2.size()) 116 return false; 117 118 for (unsigned i = 0, e = M1.size(); i != e; ++i) 119 if (M1[i] != M2[i]) 120 return false; 121 122 return true; 123} 124 125static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, 126 unsigned Level) { 127 // We don't need a shuffle if we just want to have element 0 in position 0 of 128 // the vector. 129 if (!SI && Level == 0 && IsLeft) 130 return true; 131 else if (!SI) 132 return false; 133 134 SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1); 135 136 // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether 137 // we look at the left or right side. 138 for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2) 139 Mask[i] = val; 140 141 SmallVector<int, 16> ActualMask = SI->getShuffleMask(); 142 if (!matchMask(Mask, ActualMask)) 143 return false; 144 145 return true; 146} 147 148static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, 149 unsigned Level, unsigned NumLevels) { 150 // Match one level of pairwise operations. 
151 // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, 152 // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> 153 // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, 154 // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> 155 // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 156 if (BinOp == 0) 157 return false; 158 159 assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); 160 161 unsigned Opcode = BinOp->getOpcode(); 162 Value *L = BinOp->getOperand(0); 163 Value *R = BinOp->getOperand(1); 164 165 ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L); 166 if (!LS && Level) 167 return false; 168 ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R); 169 if (!RS && Level) 170 return false; 171 172 // On level 0 we can omit one shufflevector instruction. 173 if (!Level && !RS && !LS) 174 return false; 175 176 // Shuffle inputs must match. 177 Value *NextLevelOpL = LS ? LS->getOperand(0) : 0; 178 Value *NextLevelOpR = RS ? RS->getOperand(0) : 0; 179 Value *NextLevelOp = 0; 180 if (NextLevelOpR && NextLevelOpL) { 181 // If we have two shuffles their operands must match. 182 if (NextLevelOpL != NextLevelOpR) 183 return false; 184 185 NextLevelOp = NextLevelOpL; 186 } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { 187 // On the first level we can omit the shufflevector <0, undef,...>. So the 188 // input to the other shufflevector <1, undef> must match with one of the 189 // inputs to the current binary operation. 190 // Example: 191 // %NextLevelOpL = shufflevector %R, <1, undef ...> 192 // %BinOp = fadd %NextLevelOpL, %R 193 if (NextLevelOpL && NextLevelOpL != R) 194 return false; 195 else if (NextLevelOpR && NextLevelOpR != L) 196 return false; 197 198 NextLevelOp = NextLevelOpL ? R : L; 199 } else 200 return false; 201 202 // Check that the next levels binary operation exists and matches with the 203 // current one. 
204 BinaryOperator *NextLevelBinOp = 0; 205 if (Level + 1 != NumLevels) { 206 if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp))) 207 return false; 208 else if (NextLevelBinOp->getOpcode() != Opcode) 209 return false; 210 } 211 212 // Shuffle mask for pairwise operation must match. 213 if (matchPairwiseShuffleMask(LS, true, Level)) { 214 if (!matchPairwiseShuffleMask(RS, false, Level)) 215 return false; 216 } else if (matchPairwiseShuffleMask(RS, true, Level)) { 217 if (!matchPairwiseShuffleMask(LS, false, Level)) 218 return false; 219 } else 220 return false; 221 222 if (++Level == NumLevels) 223 return true; 224 225 // Match next level. 226 return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); 227} 228 229static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, 230 unsigned &Opcode, Type *&Ty) { 231 if (!EnableReduxCost) 232 return false; 233 234 // Need to extract the first element. 235 ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); 236 unsigned Idx = ~0u; 237 if (CI) 238 Idx = CI->getZExtValue(); 239 if (Idx != 0) 240 return false; 241 242 BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); 243 if (!RdxStart) 244 return false; 245 246 Type *VecTy = ReduxRoot->getOperand(0)->getType(); 247 unsigned NumVecElems = VecTy->getVectorNumElements(); 248 if (!isPowerOf2_32(NumVecElems)) 249 return false; 250 251 // We look for a sequence of shuffle,shuffle,add triples like the following 252 // that builds a pairwise reduction tree. 
253 // 254 // (X0, X1, X2, X3) 255 // (X0 + X1, X2 + X3, undef, undef) 256 // ((X0 + X1) + (X2 + X3), undef, undef, undef) 257 // 258 // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, 259 // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> 260 // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, 261 // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> 262 // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 263 // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, 264 // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef> 265 // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, 266 // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> 267 // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 268 // %r = extractelement <4 x float> %bin.rdx8, i32 0 269 if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) 270 return false; 271 272 Opcode = RdxStart->getOpcode(); 273 Ty = VecTy; 274 275 return true; 276} 277 278static std::pair<Value *, ShuffleVectorInst *> 279getShuffleAndOtherOprd(BinaryOperator *B) { 280 281 Value *L = B->getOperand(0); 282 Value *R = B->getOperand(1); 283 ShuffleVectorInst *S = 0; 284 285 if ((S = dyn_cast<ShuffleVectorInst>(L))) 286 return std::make_pair(R, S); 287 288 S = dyn_cast<ShuffleVectorInst>(R); 289 return std::make_pair(L, S); 290} 291 292static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, 293 unsigned &Opcode, Type *&Ty) { 294 if (!EnableReduxCost) 295 return false; 296 297 // Need to extract the first element. 
298 ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); 299 unsigned Idx = ~0u; 300 if (CI) 301 Idx = CI->getZExtValue(); 302 if (Idx != 0) 303 return false; 304 305 BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); 306 if (!RdxStart) 307 return false; 308 unsigned RdxOpcode = RdxStart->getOpcode(); 309 310 Type *VecTy = ReduxRoot->getOperand(0)->getType(); 311 unsigned NumVecElems = VecTy->getVectorNumElements(); 312 if (!isPowerOf2_32(NumVecElems)) 313 return false; 314 315 // We look for a sequence of shuffles and adds like the following matching one 316 // fadd, shuffle vector pair at a time. 317 // 318 // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef, 319 // <4 x i32> <i32 2, i32 3, i32 undef, i32 undef> 320 // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf 321 // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef, 322 // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> 323 // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7 324 // %r = extractelement <4 x float> %bin.rdx8, i32 0 325 326 unsigned MaskStart = 1; 327 Value *RdxOp = RdxStart; 328 SmallVector<int, 32> ShuffleMask(NumVecElems, 0); 329 unsigned NumVecElemsRemain = NumVecElems; 330 while (NumVecElemsRemain - 1) { 331 // Check for the right reduction operation. 332 BinaryOperator *BinOp; 333 if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp))) 334 return false; 335 if (BinOp->getOpcode() != RdxOpcode) 336 return false; 337 338 Value *NextRdxOp; 339 ShuffleVectorInst *Shuffle; 340 tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); 341 342 // Check the current reduction operation and the shuffle use the same value. 343 if (Shuffle == 0) 344 return false; 345 if (Shuffle->getOperand(0) != NextRdxOp) 346 return false; 347 348 // Check that shuffle masks matches. 349 for (unsigned j = 0; j != MaskStart; ++j) 350 ShuffleMask[j] = MaskStart + j; 351 // Fill the rest of the mask with -1 for undef. 
352 std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1); 353 354 SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); 355 if (!matchMask(ShuffleMask, Mask)) 356 return false; 357 358 RdxOp = NextRdxOp; 359 NumVecElemsRemain /= 2; 360 MaskStart *= 2; 361 } 362 363 Opcode = RdxOpcode; 364 Ty = VecTy; 365 return true; 366} 367 368unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { 369 if (!TTI) 370 return -1; 371 372 switch (I->getOpcode()) { 373 case Instruction::GetElementPtr:{ 374 Type *ValTy = I->getOperand(0)->getType()->getPointerElementType(); 375 return TTI->getAddressComputationCost(ValTy); 376 } 377 378 case Instruction::Ret: 379 case Instruction::PHI: 380 case Instruction::Br: { 381 return TTI->getCFInstrCost(I->getOpcode()); 382 } 383 case Instruction::Add: 384 case Instruction::FAdd: 385 case Instruction::Sub: 386 case Instruction::FSub: 387 case Instruction::Mul: 388 case Instruction::FMul: 389 case Instruction::UDiv: 390 case Instruction::SDiv: 391 case Instruction::FDiv: 392 case Instruction::URem: 393 case Instruction::SRem: 394 case Instruction::FRem: 395 case Instruction::Shl: 396 case Instruction::LShr: 397 case Instruction::AShr: 398 case Instruction::And: 399 case Instruction::Or: 400 case Instruction::Xor: { 401 TargetTransformInfo::OperandValueKind Op1VK = 402 getOperandInfo(I->getOperand(0)); 403 TargetTransformInfo::OperandValueKind Op2VK = 404 getOperandInfo(I->getOperand(1)); 405 return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, 406 Op2VK); 407 } 408 case Instruction::Select: { 409 const SelectInst *SI = cast<SelectInst>(I); 410 Type *CondTy = SI->getCondition()->getType(); 411 return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy); 412 } 413 case Instruction::ICmp: 414 case Instruction::FCmp: { 415 Type *ValTy = I->getOperand(0)->getType(); 416 return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy); 417 } 418 case Instruction::Store: { 419 const StoreInst *SI = 
cast<StoreInst>(I); 420 Type *ValTy = SI->getValueOperand()->getType(); 421 return TTI->getMemoryOpCost(I->getOpcode(), ValTy, 422 SI->getAlignment(), 423 SI->getPointerAddressSpace()); 424 } 425 case Instruction::Load: { 426 const LoadInst *LI = cast<LoadInst>(I); 427 return TTI->getMemoryOpCost(I->getOpcode(), I->getType(), 428 LI->getAlignment(), 429 LI->getPointerAddressSpace()); 430 } 431 case Instruction::ZExt: 432 case Instruction::SExt: 433 case Instruction::FPToUI: 434 case Instruction::FPToSI: 435 case Instruction::FPExt: 436 case Instruction::PtrToInt: 437 case Instruction::IntToPtr: 438 case Instruction::SIToFP: 439 case Instruction::UIToFP: 440 case Instruction::Trunc: 441 case Instruction::FPTrunc: 442 case Instruction::BitCast: { 443 Type *SrcTy = I->getOperand(0)->getType(); 444 return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy); 445 } 446 case Instruction::ExtractElement: { 447 const ExtractElementInst * EEI = cast<ExtractElementInst>(I); 448 ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1)); 449 unsigned Idx = -1; 450 if (CI) 451 Idx = CI->getZExtValue(); 452 453 // Try to match a reduction sequence (series of shufflevector and vector 454 // adds followed by a extractelement). 
455 unsigned ReduxOpCode; 456 Type *ReduxType; 457 458 if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) 459 return TTI->getReductionCost(ReduxOpCode, ReduxType, false); 460 else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) 461 return TTI->getReductionCost(ReduxOpCode, ReduxType, true); 462 463 return TTI->getVectorInstrCost(I->getOpcode(), 464 EEI->getOperand(0)->getType(), Idx); 465 } 466 case Instruction::InsertElement: { 467 const InsertElementInst * IE = cast<InsertElementInst>(I); 468 ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2)); 469 unsigned Idx = -1; 470 if (CI) 471 Idx = CI->getZExtValue(); 472 return TTI->getVectorInstrCost(I->getOpcode(), 473 IE->getType(), Idx); 474 } 475 case Instruction::ShuffleVector: { 476 const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); 477 Type *VecTypOp0 = Shuffle->getOperand(0)->getType(); 478 unsigned NumVecElems = VecTypOp0->getVectorNumElements(); 479 SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); 480 481 if (NumVecElems == Mask.size() && isReverseVectorMask(Mask)) 482 return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, 0, 483 0); 484 return -1; 485 } 486 case Instruction::Call: 487 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { 488 SmallVector<Type*, 4> Tys; 489 for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J) 490 Tys.push_back(II->getArgOperand(J)->getType()); 491 492 return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), 493 Tys); 494 } 495 return -1; 496 default: 497 // We don't have any information on this instruction. 
498 return -1; 499 } 500} 501 502void CostModelAnalysis::print(raw_ostream &OS, const Module*) const { 503 if (!F) 504 return; 505 506 for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) { 507 for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) { 508 Instruction *Inst = it; 509 unsigned Cost = getInstructionCost(Inst); 510 if (Cost != (unsigned)-1) 511 OS << "Cost Model: Found an estimated cost of " << Cost; 512 else 513 OS << "Cost Model: Unknown cost"; 514 515 OS << " for instruction: "<< *Inst << "\n"; 516 } 517 } 518} 519