diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 85c8d3996bba51..0b16a8b7676923 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) { /// destination type followed by shuffle. This can enable further transforms by /// moving bitcasts or shuffles together. bool VectorCombine::foldBitcastShuffle(Instruction &I) { - Value *V; + Value *V0; ArrayRef Mask; - if (!match(&I, m_BitCast( - m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask)))))) + if (!match(&I, m_BitCast(m_OneUse( + m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask)))))) return false; // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for @@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) { // 2) Disallow non-vector casts. // TODO: We could allow any shuffle. auto *DestTy = dyn_cast(I.getType()); - auto *SrcTy = dyn_cast(V->getType()); + auto *SrcTy = dyn_cast(V0->getType()); if (!DestTy || !SrcTy) return false; @@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) { // Bitcast the shuffle src - keep its original width but using the destination // scalar type. unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize; - auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); - - // The new shuffle must not cost more than the old shuffle. The bitcast is - // moved ahead of the shuffle, so assume that it has the same cost as before. - InstructionCost DestCost = TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask); + auto *NewShuffleTy = + FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); + auto *OldShuffleTy = + FixedVectorType::get(SrcTy->getScalarType(), Mask.size()); + + // The new shuffle must not cost more than the old shuffle. + TargetTransformInfo::TargetCostKind CK = + TargetTransformInfo::TCK_RecipThroughput; + TargetTransformInfo::ShuffleKind SK = + TargetTransformInfo::SK_PermuteSingleSrc; + + InstructionCost DestCost = + TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) + + TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy, + TargetTransformInfo::CastContextHint::None, CK); InstructionCost SrcCost = - TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask); + TTI.getShuffleCost(SK, SrcTy, Mask, CK) + + TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy, + TargetTransformInfo::CastContextHint::None, CK); if (DestCost > SrcCost || !DestCost.isValid()) return false; - // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC' + // bitcast (shuf V0, MaskC) --> shuf (bitcast V0), MaskC' ++NumShufOfBitcast; - Value *CastV = Builder.CreateBitCast(V, ShuffleTy); + Value *CastV = Builder.CreateBitCast(V0, NewShuffleTy); Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask); replaceValue(I, *Shuf); return true;