Skip to content

Commit

Permalink
[VectorCombine] foldBitcastShuffle - include the cost of bitcasts in …
Browse files Browse the repository at this point in the history
…the comparison

This makes no real difference currently as we only fold unary shuffles, but I'm hoping to handle binary shuffles in a future patch.
  • Loading branch information
RKSimon committed Mar 20, 2024
1 parent 6086937 commit fe2119a
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
Value *V;
Value *V0;
ArrayRef<int> Mask;
if (!match(&I, m_BitCast(
m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
if (!match(&I, m_BitCast(m_OneUse(
m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))))
return false;

// 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
Expand All @@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
// 2) Disallow non-vector casts.
// TODO: We could allow any shuffle.
auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
if (!DestTy || !SrcTy)
return false;

Expand Down Expand Up @@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
// Bitcast the shuffle src - keep its original width but using the destination
// scalar type.
unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);

// The new shuffle must not cost more than the old shuffle. The bitcast is
// moved ahead of the shuffle, so assume that it has the same cost as before.
InstructionCost DestCost = TTI.getShuffleCost(
TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
auto *NewShuffleTy =
FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
auto *OldShuffleTy =
FixedVectorType::get(SrcTy->getScalarType(), Mask.size());

// The new shuffle must not cost more than the old shuffle.
TargetTransformInfo::TargetCostKind CK =
TargetTransformInfo::TCK_RecipThroughput;
TargetTransformInfo::ShuffleKind SK =
TargetTransformInfo::SK_PermuteSingleSrc;

InstructionCost DestCost =
TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) +
TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
TargetTransformInfo::CastContextHint::None, CK);
InstructionCost SrcCost =
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
TTI.getShuffleCost(SK, SrcTy, Mask, CK) +
TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
TargetTransformInfo::CastContextHint::None, CK);
if (DestCost > SrcCost || !DestCost.isValid())
return false;

// bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
// bitcast (shuf V0, MaskC) --> shuf (bitcast V0), MaskC'
++NumShufOfBitcast;
Value *CastV = Builder.CreateBitCast(V, ShuffleTy);
Value *CastV = Builder.CreateBitCast(V0, NewShuffleTy);
Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
replaceValue(I, *Shuf);
return true;
Expand Down

0 comments on commit fe2119a

Please sign in to comment.