From f6b825f51ec8a67c4ace43aaacc27bfd4a78f706 Mon Sep 17 00:00:00 2001 From: Amara Emerson Date: Thu, 7 Mar 2024 23:15:58 -0800 Subject: [PATCH 1/4] Revert "Revert "[AArch64][GlobalISel] Fix incorrect selection of monotonic s32->s64 anyext load."" Attempt 2. The first one was trying to call isa<> on an MI reference that was free'd. This reverts commit ee24409c40ff35c3221892d9723331c233ca9f0e. --- .../GISel/AArch64InstructionSelector.cpp | 9 ++--- .../GlobalISel/select-atomic-load-store.mir | 33 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp index 6652883792391b..0f3c3cb96e6ce3 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -2997,13 +2997,14 @@ bool AArch64InstructionSelector::select(MachineInstr &I) { } } - if (IsZExtLoad) { - // The zextload from a smaller type to i32 should be handled by the + if (IsZExtLoad || (Opcode == TargetOpcode::G_LOAD && + ValTy == LLT::scalar(64) && MemSizeInBits == 32)) { + // The any/zextload from a smaller type to i32 should be handled by the // importer. if (MRI.getType(LoadStore->getOperand(0).getReg()).getSizeInBits() != 64) return false; - // If we have a ZEXTLOAD then change the load's type to be a narrower reg - // and zero_extend with SUBREG_TO_REG. + // If we have an extending load then change the load's type to be a + // narrower reg and zero_extend with SUBREG_TO_REG. 
Register LdReg = MRI.createVirtualRegister(&AArch64::GPR32RegClass); Register DstReg = LoadStore->getOperand(0).getReg(); LoadStore->getOperand(0).setReg(LdReg); diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/select-atomic-load-store.mir b/llvm/test/CodeGen/AArch64/GlobalISel/select-atomic-load-store.mir index 5787f914b965d3..6b4bbb85b2ec44 100644 --- a/llvm/test/CodeGen/AArch64/GlobalISel/select-atomic-load-store.mir +++ b/llvm/test/CodeGen/AArch64/GlobalISel/select-atomic-load-store.mir @@ -9,6 +9,11 @@ ret i8 %v } + define i32 @anyext_load_monotonic_i32() { + %v = load atomic i32, ptr null monotonic, align 4 + ret i32 %v + } + ... --- name: load_acq_i8 @@ -25,13 +30,33 @@ body: | ; CHECK-LABEL: name: load_acq_i8 ; CHECK: liveins: $x0 - ; CHECK: [[COPY:%[0-9]+]]:gpr64sp = COPY $x0 - ; CHECK: [[LDARB:%[0-9]+]]:gpr32 = LDARB [[COPY]] :: (load acquire (s8) from %ir.ptr, align 8) - ; CHECK: $w0 = COPY [[LDARB]] - ; CHECK: RET_ReallyLR implicit $w0 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr64sp = COPY $x0 + ; CHECK-NEXT: [[LDARB:%[0-9]+]]:gpr32 = LDARB [[COPY]] :: (load acquire (s8) from %ir.ptr, align 8) + ; CHECK-NEXT: $w0 = COPY [[LDARB]] + ; CHECK-NEXT: RET_ReallyLR implicit $w0 %0:gpr(p0) = COPY $x0 %2:gpr(s32) = G_LOAD %0(p0) :: (load acquire (s8) from %ir.ptr, align 8) $w0 = COPY %2(s32) RET_ReallyLR implicit $w0 ... +--- +name: anyext_load_monotonic_i32 +legalized: true +regBankSelected: true +tracksRegLiveness: true +body: | + bb.1: + ; CHECK-LABEL: name: anyext_load_monotonic_i32 + ; CHECK: [[COPY:%[0-9]+]]:gpr64common = COPY $xzr + ; CHECK-NEXT: [[LDRWui:%[0-9]+]]:gpr32 = LDRWui [[COPY]], 0 :: (load monotonic (s32) from `ptr null`) + ; CHECK-NEXT: %ld:gpr64all = SUBREG_TO_REG 0, [[LDRWui]], %subreg.sub_32 + ; CHECK-NEXT: $x0 = COPY %ld + ; CHECK-NEXT: RET_ReallyLR implicit $x0 + %1:gpr(p0) = G_CONSTANT i64 0 + %ld:gpr(s64) = G_LOAD %1(p0) :: (load monotonic (s32) from `ptr null`) + $x0 = COPY %ld(s64) + RET_ReallyLR implicit $x0 + +... 
From b6a340023d383d1e77cb8d91d92c096f791fa8c0 Mon Sep 17 00:00:00 2001 From: Vlad Serebrennikov Date: Fri, 8 Mar 2024 11:31:00 +0400 Subject: [PATCH 2/4] [clang] Respect field alignment in layout compatibility of structs (#84313) This patch implements [CWG2583](https://cplusplus.github.io/CWG/issues/2583.html) "Common initial sequence should consider over-alignment". Note that alignment of union members doesn't have to match, as layout compatibility of unions is not defined in terms of common initial sequence (http://eel.is/c++draft/class.mem.general#25). --- clang/docs/ReleaseNotes.rst | 4 ++++ clang/lib/Sema/SemaChecking.cpp | 23 +++++++++++++++++++++-- clang/test/CXX/drs/dr25xx.cpp | 26 ++++++++++++++++++++++++++ clang/test/SemaCXX/type-traits.cpp | 13 ++++++++++++- clang/www/cxx_dr_status.html | 2 +- 5 files changed, 64 insertions(+), 4 deletions(-) diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index fa23c215790f11..fe7bbe437831ed 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -115,6 +115,10 @@ Resolutions to C++ Defect Reports of two types. (`CWG1719: Layout compatibility and cv-qualification revisited `_). +- Alignment of members is now respected when evaluating layout compatibility + of structs. + (`CWG2583: Common initial sequence should consider over-alignment `_). + - ``[[no_unique_address]]`` is now respected when evaluating layout compatibility of two types. (`CWG2759: [[no_unique_address] and common initial sequence `_). diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 3597f93a017136..b34b8df0020137 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -19184,8 +19184,22 @@ static bool isLayoutCompatible(ASTContext &C, EnumDecl *ED1, EnumDecl *ED2) { } /// Check if two fields are layout-compatible. +/// Can be used on union members, which are exempt from alignment requirement +/// of common initial sequence. 
static bool isLayoutCompatible(ASTContext &C, FieldDecl *Field1, - FieldDecl *Field2) { + FieldDecl *Field2, + bool AreUnionMembers = false) { + const Type *Field1Parent = Field1->getParent()->getTypeForDecl(); + const Type *Field2Parent = Field2->getParent()->getTypeForDecl(); + assert(((Field1Parent->isStructureOrClassType() && + Field2Parent->isStructureOrClassType()) || + (Field1Parent->isUnionType() && Field2Parent->isUnionType())) && + "Can't evaluate layout compatibility between a struct field and a " + "union field."); + assert(((!AreUnionMembers && Field1Parent->isStructureOrClassType()) || + (AreUnionMembers && Field1Parent->isUnionType())) && + "AreUnionMembers should be 'true' for union fields (only)."); + if (!isLayoutCompatible(C, Field1->getType(), Field2->getType())) return false; @@ -19204,6 +19218,11 @@ static bool isLayoutCompatible(ASTContext &C, FieldDecl *Field1, if (Field1->hasAttr() || Field2->hasAttr()) return false; + + if (!AreUnionMembers && + Field1->getMaxAlignment() != Field2->getMaxAlignment()) + return false; + return true; } @@ -19265,7 +19284,7 @@ static bool isLayoutCompatibleUnion(ASTContext &C, RecordDecl *RD1, E = UnmatchedFields.end(); for ( ; I != E; ++I) { - if (isLayoutCompatible(C, Field1, *I)) { + if (isLayoutCompatible(C, Field1, *I, /*AreUnionMembers=*/true)) { bool Result = UnmatchedFields.erase(*I); (void) Result; assert(Result); diff --git a/clang/test/CXX/drs/dr25xx.cpp b/clang/test/CXX/drs/dr25xx.cpp index 9fc7cf59485caa..46532486e50e53 100644 --- a/clang/test/CXX/drs/dr25xx.cpp +++ b/clang/test/CXX/drs/dr25xx.cpp @@ -211,6 +211,32 @@ namespace dr2565 { // dr2565: 16 open 2023-06-07 #endif } +namespace dr2583 { // dr2583: 19 +#if __cplusplus >= 201103L +struct A { + int i; + char c; +}; + +struct B { + int i; + alignas(8) char c; +}; + +union U { + A a; + B b; +}; + +union V { + A a; + alignas(64) B b; +}; + +static_assert(!__is_layout_compatible(A, B), ""); +static_assert(__is_layout_compatible(U, V), ""); +#endif 
+} // namespace dr2583 namespace dr2598 { // dr2598: 18 #if __cplusplus >= 201103L diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp index 23c339ebdf0826..831de2589dcb9e 100644 --- a/clang/test/SemaCXX/type-traits.cpp +++ b/clang/test/SemaCXX/type-traits.cpp @@ -1681,6 +1681,16 @@ union UnionLayout3 { [[no_unique_address]] CEmptyStruct d; }; +union UnionNoOveralignedMembers { + int a; + double b; +}; + +union UnionWithOveralignedMembers { + int a; + alignas(16) double b; +}; + struct StructWithAnonUnion { union { int a; @@ -1771,7 +1781,8 @@ void is_layout_compatible(int n) static_assert(__is_layout_compatible(CStruct, CStructNoUniqueAddress) != bool(__has_cpp_attribute(no_unique_address)), ""); static_assert(__is_layout_compatible(CStructNoUniqueAddress, CStructNoUniqueAddress2) != bool(__has_cpp_attribute(no_unique_address)), ""); static_assert(__is_layout_compatible(CStruct, CStructAlignment), ""); - static_assert(__is_layout_compatible(CStruct, CStructAlignedMembers), ""); // FIXME: alignment of members impact common initial sequence + static_assert(!__is_layout_compatible(CStruct, CStructAlignedMembers), ""); + static_assert(__is_layout_compatible(UnionNoOveralignedMembers, UnionWithOveralignedMembers), ""); static_assert(__is_layout_compatible(CStructWithBitfelds, CStructWithBitfelds), ""); static_assert(__is_layout_compatible(CStructWithBitfelds, CStructWithBitfelds2), ""); static_assert(!__is_layout_compatible(CStructWithBitfelds, CStructWithBitfelds3), ""); diff --git a/clang/www/cxx_dr_status.html b/clang/www/cxx_dr_status.html index 503472a2cae4eb..c20a5d021e9d95 100755 --- a/clang/www/cxx_dr_status.html +++ b/clang/www/cxx_dr_status.html @@ -15306,7 +15306,7 @@

C++ defect report implementation status

2583 C++23 Common initial sequence should consider over-alignment - Unknown + Clang 19 2584 From fb1be9b33ca3ed1b7ea54b15bd77fd868726b57c Mon Sep 17 00:00:00 2001 From: Vyacheslav Levytskyy Date: Fri, 8 Mar 2024 08:31:56 +0100 Subject: [PATCH 3/4] [SPIR-V] Insert a bitcast before load/store instruction to keep SPIR-V code valid (#84069) This PR introduces a step after instruction selection in which instructions are traversed and checked for validity against the specification. The PR also adds a way to correct load/store when there is a type mismatch that contradicts the specification -- an additional bitcast is inserted to keep types consistent. Corresponding test cases are added and existing test cases are corrected. This PR makes it possible to successfully validate with the `spirv-val` tool (https://github.com/KhronosGroup/SPIRV-Tools) some output that previously led to validation errors and to crashes during back translation from SPIRV to LLVM IR in the SPIRV Translator project (https://github.com/KhronosGroup/SPIRV-LLVM-Translator). The added step of bringing instruction types into the correspondence required by the specification can be (should be and will be) extended beyond load/store instructions to enforce the validity rules of other SPIRV instructions related to type inference. 
--- llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 4 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 93 ++++++++++--------- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 1 + llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 80 ++++++++++++++++ llvm/lib/Target/SPIRV/SPIRVISelLowering.h | 12 ++- .../SPIRV/constant/global-constants.ll | 1 + .../SPIRV/pointers/bitcast-fix-load.ll | 21 +++++ .../SPIRV/pointers/bitcast-fix-store.ll | 31 +++++++ llvm/test/CodeGen/SPIRV/spirv-load-store.ll | 11 ++- 9 files changed, 203 insertions(+), 51 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll create mode 100644 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 5c432d68273234..575e903d05bb97 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/IR/TypedPointerType.h" #include @@ -446,7 +447,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) { Value *ArgOperand = CI->getArgOperand(OpIdx); - if (!isa(ArgOperand->getType())) + if (!isa(ArgOperand->getType()) && + !isa(ArgOperand->getType())) continue; // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs() diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index e88298f52fbe18..8556581996fede 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -20,6 +20,7 @@ #include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" +#include "llvm/IR/TypedPointerType.h" using namespace llvm; SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) @@ -420,9 +421,10 
@@ Register SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { const Type *LLVMTy = getTypeForSPIRVType(SpvType); - const PointerType *LLVMPtrTy = cast(LLVMTy); + const TypedPointerType *LLVMPtrTy = cast(LLVMTy); // Find a constant in DT or build a new one. - Constant *CP = ConstantPointerNull::get(const_cast(LLVMPtrTy)); + Constant *CP = ConstantPointerNull::get(PointerType::get( + LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace())); Register Res = DT.find(CP, CurMF); if (!Res.isValid()) { LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); @@ -517,6 +519,13 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); MRI->setType(Reg, RegLLTy); assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); + } else { + // Our knowledge about the type may be updated. + // If that's the case, we need to update a type + // associated with the register. + SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg); + if (!DefType || DefType != BaseType) + assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); } // If it's a global variable with name, output OpName for it. @@ -705,33 +714,37 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( } return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); } - if (auto PType = dyn_cast(Ty)) { - SPIRVType *SpvElementType; - // At the moment, all opaque pointers correspond to i8 element type. - // TODO: change the implementation once opaque pointers are supported - // in the SPIR-V specification. - SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); - // Get access to information about available extensions - const SPIRVSubtarget *ST = - static_cast(&MIRBuilder.getMF().getSubtarget()); - auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST); - // Null pointer means we have a loop in type definitions, make and - // return corresponding OpTypeForwardPointer. 
- if (SpvElementType == nullptr) { - if (!ForwardPointerTypes.contains(Ty)) - ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder); - return ForwardPointerTypes[PType]; - } - // If we have forward pointer associated with this type, use its register - // operand to create OpTypePointer. - if (ForwardPointerTypes.contains(PType)) { - Register Reg = getSPIRVTypeID(ForwardPointerTypes[PType]); - return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); - } - - return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC); + unsigned AddrSpace = 0xFFFF; + if (auto PType = dyn_cast(Ty)) + AddrSpace = PType->getAddressSpace(); + else if (auto PType = dyn_cast(Ty)) + AddrSpace = PType->getAddressSpace(); + else + report_fatal_error("Unable to convert LLVM type to SPIRVType", true); + SPIRVType *SpvElementType; + // At the moment, all opaque pointers correspond to i8 element type. + // TODO: change the implementation once opaque pointers are supported + // in the SPIR-V specification. + SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); + // Get access to information about available extensions + const SPIRVSubtarget *ST = + static_cast(&MIRBuilder.getMF().getSubtarget()); + auto SC = addressSpaceToStorageClass(AddrSpace, *ST); + // Null pointer means we have a loop in type definitions, make and + // return corresponding OpTypeForwardPointer. + if (SpvElementType == nullptr) { + if (!ForwardPointerTypes.contains(Ty)) + ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder); + return ForwardPointerTypes[Ty]; + } + // If we have forward pointer associated with this type, use its register + // operand to create OpTypePointer. 
+ if (ForwardPointerTypes.contains(Ty)) { + Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]); + return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); } - llvm_unreachable("Unable to convert LLVM type to SPIRVType"); + + return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC); } SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( @@ -1139,11 +1152,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( SPIRV::StorageClass::StorageClass SC) { const Type *PointerElementType = getTypeForSPIRVType(BaseType); unsigned AddressSpace = storageClassToAddressSpace(SC); - Type *LLVMTy = - PointerType::get(const_cast(PointerElementType), AddressSpace); + Type *LLVMTy = TypedPointerType::get(const_cast(PointerElementType), + AddressSpace); + // check if this type is already available Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); if (Reg.isValid()) return getSPIRVTypeForVReg(Reg); + // create a new type auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(), MIRBuilder.getDebugLoc(), MIRBuilder.getTII().get(SPIRV::OpTypePointer)) @@ -1155,22 +1170,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( } SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( - SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &, SPIRV::StorageClass::StorageClass SC) { - const Type *PointerElementType = getTypeForSPIRVType(BaseType); - unsigned AddressSpace = storageClassToAddressSpace(SC); - Type *LLVMTy = - PointerType::get(const_cast(PointerElementType), AddressSpace); - Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); - if (Reg.isValid()) - return getSPIRVTypeForVReg(Reg); - MachineBasicBlock &BB = *I.getParent(); - auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) - .addDef(createTypeVReg(CurMF->getRegInfo())) - .addImm(static_cast(SC)) - .addUse(getSPIRVTypeID(BaseType)); - 
DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); - return finishCreatingSPIRVType(LLVMTy, MIB); + MachineIRBuilder MIRBuilder(I); + return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC); } Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index f5a83072c19d76..9c0061d13fd0cf 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -34,6 +34,7 @@ class SPIRVGlobalRegistry { DenseMap> VRegToTypeMap; + // Map LLVM Type* to SPIRVGeneralDuplicatesTracker DT; DenseMap SPIRVToLLVMType; diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 33c6aa242969de..61748070fc0fb2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -12,6 +12,13 @@ #include "SPIRVISelLowering.h" #include "SPIRV.h" +#include "SPIRVInstrInfo.h" +#include "SPIRVRegisterBankInfo.h" +#include "SPIRVRegisterInfo.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/IR/IntrinsicsSPIRV.h" #define DEBUG_TYPE "spirv-lower" @@ -74,3 +81,76 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, } return false; } + +// Insert a bitcast before the instruction to keep SPIR-V code valid +// when there is a type mismatch between results and operand types. +static void validatePtrTypes(const SPIRVSubtarget &STI, + MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, + MachineInstr &I, SPIRVType *ResType, + unsigned OpIdx) { + Register OpReg = I.getOperand(OpIdx).getReg(); + SPIRVType *TypeInst = MRI->getVRegDef(OpReg); + SPIRVType *OpType = GR.getSPIRVTypeForVReg( + TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter + ? 
TypeInst->getOperand(1).getReg() + : OpReg); + if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) + return; + SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); + if (!ElemType || ElemType == ResType) + return; + // There is a type mismatch between results and operand types + // and we insert a bitcast before the instruction to keep SPIR-V code valid + SPIRV::StorageClass::StorageClass SC = + static_cast( + OpType->getOperand(1).getImm()); + MachineInstr *PrevI = I.getPrevNode(); + MachineBasicBlock &MBB = *I.getParent(); + MachineBasicBlock::iterator InsPt = + PrevI ? PrevI->getIterator() : MBB.begin(); + MachineIRBuilder MIB(MBB, InsPt); + SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC); + if (!GR.isBitcastCompatible(NewPtrType, OpType)) + report_fatal_error( + "insert validation bitcast: incompatible result and operand types"); + Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + bool Res = MIB.buildInstr(SPIRV::OpBitcast) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(NewPtrType)) + .addUse(OpReg) + .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), + *STI.getRegBankInfo()); + if (!Res) + report_fatal_error("insert validation bitcast: cannot constrain all uses"); + MRI->setRegClass(NewReg, &SPIRV::IDRegClass); + GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); + I.getOperand(OpIdx).setReg(NewReg); +} + +// TODO: the logic of inserting additional bitcast's is to be moved +// to pre-IRTranslation passes eventually +void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { + MachineRegisterInfo *MRI = &MF.getRegInfo(); + SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); + GR.setCurrentFunc(MF); + for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { + MachineBasicBlock *MBB = &*I; + for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end(); + MBBI != MBBE;) { + MachineInstr &MI = *MBBI++; + 
switch (MI.getOpcode()) { + case SPIRV::OpLoad: + // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> + validatePtrTypes(STI, MRI, GR, MI, + GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2); + break; + case SPIRV::OpStore: + // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type + validatePtrTypes(STI, MRI, GR, MI, + GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0); + break; + } + } + } + TargetLowering::finalizeLowering(MF); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h index d34f802e9d889f..b01571bfc1eeb5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -14,16 +14,19 @@ #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H #define LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H +#include "SPIRVGlobalRegistry.h" #include "llvm/CodeGen/TargetLowering.h" namespace llvm { class SPIRVSubtarget; class SPIRVTargetLowering : public TargetLowering { + const SPIRVSubtarget &STI; + public: explicit SPIRVTargetLowering(const TargetMachine &TM, - const SPIRVSubtarget &STI) - : TargetLowering(TM) {} + const SPIRVSubtarget &ST) + : TargetLowering(TM), STI(ST) {} // Stop IRTranslator breaking up FMA instrs to preserve types information. bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF, @@ -47,6 +50,11 @@ class SPIRVTargetLowering : public TargetLowering { bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override; + + // Call the default implementation and finalize target lowering by inserting + // extra instructions required to preserve validity of SPIR-V code imposed by + // the standard. 
+ void finalizeLowering(MachineFunction &MF) const override; }; } // namespace llvm diff --git a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll index 916c70628d0169..74e28cbe7acb17 100644 --- a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll +++ b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll @@ -1,4 +1,5 @@ ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} @global = addrspace(1) constant i32 1 ; OpenCL global memory @constant = addrspace(2) constant i32 2 ; OpenCL constant memory diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll new file mode 100644 index 00000000000000..a30d0792e39988 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll @@ -0,0 +1,21 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TYSTRUCTLONG:]] = OpTypeStruct %[[#TYLONG]] +; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTLONG]] %[[#]] +; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]] +; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]] +; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]] +; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]] +; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]] +; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]] + +%struct.S = type { i32 } +%struct.__wrapper_class = type { [7 x %struct.S] } + +define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) { +entry: + %val = load i32, ptr %_arg_Arr + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll 
b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll new file mode 100644 index 00000000000000..4701f02ea33af3 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll @@ -0,0 +1,31 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]] +; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYLONG]] +; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 3 +; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]] +; CHECK: OpFunction +; CHECK: %[[#ARGPTR1:]] = OpFunctionParameter %[[#TYLONGPTR]] +; CHECK: OpStore %[[#ARGPTR1]] %[[#CONST:]] +; CHECK: OpFunction +; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]] +; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]] +; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]] +; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]] + +%struct.S = type { i32 } +%struct.__wrapper_class = type { [7 x %struct.S] } + +define spir_kernel void @foo(%struct.S %arg, ptr %ptr) { +entry: + store %struct.S %arg, ptr %ptr + ret void +} + +define spir_kernel void @bar(ptr %ptr) { +entry: + store i32 3, ptr %ptr + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll index a82bf0ab2e01f6..9188617312466d 100644 --- a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll +++ b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll @@ -1,9 +1,14 @@ ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} ;; Translate SPIR-V friendly OpLoad and OpStore calls -; CHECK: %[[#CONST:]] = OpConstant %[[#]] 42 -; CHECK: OpStore %[[#PTR:]] %[[#CONST]] Volatile|Aligned 4 -; CHECK: %[[#]] = OpLoad %[[#]] 
%[[#PTR]] +; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TYFLOAT:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#TYFLOATPTR:]] = OpTypePointer CrossWorkgroup %[[#TYFLOAT]] +; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 42 +; CHECK: OpStore %[[#PTRTOLONG:]] %[[#CONST]] Volatile|Aligned 4 +; CHECK: %[[#PTRTOFLOAT:]] = OpBitcast %[[#TYFLOATPTR]] %[[#PTRTOLONG]] +; CHECK: OpLoad %[[#TYFLOAT]] %[[#PTRTOFLOAT]] define weak_odr dso_local spir_kernel void @foo(i32 addrspace(1)* %var) { entry: From df9be017b7828e0a1dbb4f1f507a92266b61e680 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 8 Mar 2024 08:34:56 +0100 Subject: [PATCH 4/4] [mlir][EmitC] Add `unary_{minus,plus}` operators (#84329) This adds operations for the unary minus and the unary plus operator. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 36 +++++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 17 +++++++++- mlir/test/Dialect/EmitC/ops.mlir | 6 ++++ mlir/test/Target/Cpp/unary_operators.mlir | 12 +++++++ 4 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Target/Cpp/unary_operators.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index db0e2d10960d72..ac1e38a5506da0 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -908,6 +908,42 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> { let hasVerifier = 1; } +def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> { + let summary = "Unary minus operation"; + let description = [{ + With the `unary_minus` operation the unary operator - (minus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. 
+ int32_t v2 = -v1; + ``` + }]; +} + +def EmitC_UnaryPlusOp : EmitC_UnaryOp<"unary_plus", [CExpression]> { + let summary = "Unary plus operation"; + let description = [{ + With the `unary_plus` operation the unary operator + (plus) can be + applied. + + Example: + + ```mlir + %0 = emitc.unary_plus %arg0 : (i32) -> i32 + ``` + ```c++ + // Code emitted for the operation above. + int32_t v2 = +v1; + ``` + }]; +} + def EmitC_VariableOp : EmitC_Op<"variable", []> { let summary = "Variable operation"; let description = [{ diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 95513cb0fb2ebc..b99d0ede8bf4ff 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -104,6 +104,8 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 15; }) + .Case([&](auto op) { return 15; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); } @@ -652,6 +654,18 @@ static LogicalResult printOperation(CppEmitter &emitter, return printBinaryOperation(emitter, operation, "^"); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryPlusOp unaryPlusOp) { + Operation *operation = unaryPlusOp.getOperation(); + return printUnaryOperation(emitter, operation, "+"); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::UnaryMinusOp unaryMinusOp) { + Operation *operation = unaryMinusOp.getOperation(); + return printUnaryOperation(emitter, operation, "-"); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { raw_ostream &os = emitter.ostream(); Operation &op = *castOp.getOperation(); @@ -1371,7 +1385,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp, 
emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, - emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>( + emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index f852390f03e298..122b1d9ef1059f 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -134,6 +134,12 @@ func.func @logical(%arg0: i32, %arg1: i32) { return } +func.func @unary(%arg0: i32) { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + return +} + func.func @test_if(%arg0: i1, %arg1: f32) { emitc.if %arg0 { %0 = emitc.call_opaque "func_const"(%arg1) : (f32) -> i32 diff --git a/mlir/test/Target/Cpp/unary_operators.mlir b/mlir/test/Target/Cpp/unary_operators.mlir new file mode 100644 index 00000000000000..8a89437a41cc50 --- /dev/null +++ b/mlir/test/Target/Cpp/unary_operators.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @unary(%arg0: i32) -> () { + %0 = emitc.unary_minus %arg0 : (i32) -> i32 + %1 = emitc.unary_plus %arg0 : (i32) -> i32 + + return +} + +// CHECK-LABEL: void unary +// CHECK-NEXT: int32_t [[V1:[^ ]*]] = -[[V0:[^ ]*]]; +// CHECK-NEXT: int32_t [[V2:[^ ]*]] = +[[V0]];