Skip to content

Commit c8f168c

Browse files
authored
[SandboxIR] Remove tight-coupling with LLVM's SwitchInst::CaseHandle (#167093)
SandboxIR's SwitchInst CaseHandle was relying on LLVM IR's SwitchInst::CaseHandleImpl template, which may call private functions of SandboxIR's SwitchInst. This creates a dependency cycle which is against the design principles of Sandbox IR. The issue was exposed by: #166842 Thanks to @aengelke for raising the issue.
1 parent 36e9a0b commit c8f168c

File tree

2 files changed

+119
-21
lines changed

2 files changed

+119
-21
lines changed

llvm/include/llvm/SandboxIR/Instruction.h

Lines changed: 92 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,45 +1884,116 @@ class SwitchInst : public SingleLLVMInstructionImpl<llvm::SwitchInst> {
18841884
return cast<llvm::SwitchInst>(Val)->getNumCases();
18851885
}
18861886

1887+
template <typename LLVMCaseItT, typename BlockT, typename ConstT>
1888+
class CaseItImpl;
1889+
1890+
// The template helps avoid code duplication for const and non-const
1891+
// CaseHandle variants.
1892+
template <typename LLVMCaseItT, typename BlockT, typename ConstT>
1893+
class CaseHandleImpl {
1894+
Context &Ctx;
1895+
// NOTE: We are not wrapping an LLVM CaseHande here because it is not
1896+
// default-constructible. Instead we are wrapping the LLVM CaseIt
1897+
// iterator, as we can always get an LLVM CaseHandle by de-referencing it.
1898+
LLVMCaseItT LLVMCaseIt;
1899+
template <typename T1, typename T2, typename T3> friend class CaseItImpl;
1900+
1901+
public:
1902+
CaseHandleImpl(Context &Ctx, LLVMCaseItT LLVMCaseIt)
1903+
: Ctx(Ctx), LLVMCaseIt(LLVMCaseIt) {}
1904+
ConstT *getCaseValue() const;
1905+
BlockT *getCaseSuccessor() const;
1906+
unsigned getCaseIndex() const {
1907+
const auto &LLVMCaseHandle = *LLVMCaseIt;
1908+
return LLVMCaseHandle.getCaseIndex();
1909+
}
1910+
unsigned getSuccessorIndex() const {
1911+
const auto &LLVMCaseHandle = *LLVMCaseIt;
1912+
return LLVMCaseHandle.getSuccessorIndex();
1913+
}
1914+
};
1915+
1916+
// The template helps avoid code duplication for const and non-const CaseIt
1917+
// variants.
1918+
template <typename LLVMCaseItT, typename BlockT, typename ConstT>
1919+
class CaseItImpl : public iterator_facade_base<
1920+
CaseItImpl<LLVMCaseItT, BlockT, ConstT>,
1921+
std::random_access_iterator_tag,
1922+
const CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>> {
1923+
CaseHandleImpl<LLVMCaseItT, BlockT, ConstT> CH;
1924+
1925+
public:
1926+
CaseItImpl(Context &Ctx, LLVMCaseItT It) : CH(Ctx, It) {}
1927+
CaseItImpl(SwitchInst *SI, ptrdiff_t CaseNum)
1928+
: CH(SI->getContext(), llvm::SwitchInst::CaseIt(
1929+
cast<llvm::SwitchInst>(SI->Val), CaseNum)) {}
1930+
CaseItImpl &operator+=(ptrdiff_t N) {
1931+
CH.LLVMCaseIt += N;
1932+
return *this;
1933+
}
1934+
CaseItImpl &operator-=(ptrdiff_t N) {
1935+
CH.LLVMCaseIt -= N;
1936+
return *this;
1937+
}
1938+
ptrdiff_t operator-(const CaseItImpl &Other) const {
1939+
return CH.LLVMCaseIt - Other.CH.LLVMCaseIt;
1940+
}
1941+
bool operator==(const CaseItImpl &Other) const {
1942+
return CH.LLVMCaseIt == Other.CH.LLVMCaseIt;
1943+
}
1944+
bool operator<(const CaseItImpl &Other) const {
1945+
return CH.LLVMCaseIt < Other.CH.LLVMCaseIt;
1946+
}
1947+
const CaseHandleImpl<LLVMCaseItT, BlockT, ConstT> &operator*() const {
1948+
return CH;
1949+
}
1950+
};
1951+
18871952
using CaseHandle =
1888-
llvm::SwitchInst::CaseHandleImpl<SwitchInst, ConstantInt, BasicBlock>;
1889-
using ConstCaseHandle =
1890-
llvm::SwitchInst::CaseHandleImpl<const SwitchInst, const ConstantInt,
1891-
const BasicBlock>;
1892-
using CaseIt = llvm::SwitchInst::CaseIteratorImpl<CaseHandle>;
1893-
using ConstCaseIt = llvm::SwitchInst::CaseIteratorImpl<ConstCaseHandle>;
1953+
CaseHandleImpl<llvm::SwitchInst::CaseIt, BasicBlock, ConstantInt>;
1954+
using CaseIt = CaseItImpl<llvm::SwitchInst::CaseIt, BasicBlock, ConstantInt>;
1955+
1956+
using ConstCaseHandle = CaseHandleImpl<llvm::SwitchInst::ConstCaseIt,
1957+
const BasicBlock, const ConstantInt>;
1958+
using ConstCaseIt = CaseItImpl<llvm::SwitchInst::ConstCaseIt,
1959+
const BasicBlock, const ConstantInt>;
18941960

18951961
/// Returns a read/write iterator that points to the first case in the
18961962
/// SwitchInst.
1897-
CaseIt case_begin() { return CaseIt(this, 0); }
1898-
ConstCaseIt case_begin() const { return ConstCaseIt(this, 0); }
1963+
CaseIt case_begin() {
1964+
return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_begin());
1965+
}
1966+
ConstCaseIt case_begin() const {
1967+
return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_begin());
1968+
}
18991969
/// Returns a read/write iterator that points one past the last in the
19001970
/// SwitchInst.
1901-
CaseIt case_end() { return CaseIt(this, getNumCases()); }
1902-
ConstCaseIt case_end() const { return ConstCaseIt(this, getNumCases()); }
1971+
CaseIt case_end() {
1972+
return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_end());
1973+
}
1974+
ConstCaseIt case_end() const {
1975+
return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_end());
1976+
}
19031977
/// Iteration adapter for range-for loops.
19041978
iterator_range<CaseIt> cases() {
19051979
return make_range(case_begin(), case_end());
19061980
}
19071981
iterator_range<ConstCaseIt> cases() const {
19081982
return make_range(case_begin(), case_end());
19091983
}
1910-
CaseIt case_default() { return CaseIt(this, DefaultPseudoIndex); }
1984+
CaseIt case_default() {
1985+
return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_default());
1986+
}
19111987
ConstCaseIt case_default() const {
1912-
return ConstCaseIt(this, DefaultPseudoIndex);
1988+
return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->case_default());
19131989
}
19141990
CaseIt findCaseValue(const ConstantInt *C) {
1915-
return CaseIt(
1916-
this,
1917-
const_cast<const SwitchInst *>(this)->findCaseValue(C)->getCaseIndex());
1991+
const llvm::ConstantInt *LLVMC = cast<llvm::ConstantInt>(C->Val);
1992+
return CaseIt(Ctx, cast<llvm::SwitchInst>(Val)->findCaseValue(LLVMC));
19181993
}
19191994
ConstCaseIt findCaseValue(const ConstantInt *C) const {
1920-
ConstCaseIt I = llvm::find_if(cases(), [C](const ConstCaseHandle &Case) {
1921-
return Case.getCaseValue() == C;
1922-
});
1923-
if (I != case_end())
1924-
return I;
1925-
return case_default();
1995+
const llvm::ConstantInt *LLVMC = cast<llvm::ConstantInt>(C->Val);
1996+
return ConstCaseIt(Ctx, cast<llvm::SwitchInst>(Val)->findCaseValue(LLVMC));
19261997
}
19271998
LLVM_ABI ConstantInt *findCaseDest(BasicBlock *BB);
19281999

llvm/lib/SandboxIR/Instruction.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,33 @@ void SwitchInst::setDefaultDest(BasicBlock *DefaultCase) {
11251125
cast<llvm::SwitchInst>(Val)->setDefaultDest(
11261126
cast<llvm::BasicBlock>(DefaultCase->Val));
11271127
}
1128+
1129+
template <typename LLVMCaseItT, typename BlockT, typename ConstT>
1130+
ConstT *
1131+
SwitchInst::CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>::getCaseValue() const {
1132+
const auto &LLVMCaseHandle = *LLVMCaseIt;
1133+
auto *LLVMC = Ctx.getValue(LLVMCaseHandle.getCaseValue());
1134+
return cast<ConstT>(LLVMC);
1135+
}
1136+
1137+
template <typename LLVMCaseItT, typename BlockT, typename ConstT>
1138+
BlockT *
1139+
SwitchInst::CaseHandleImpl<LLVMCaseItT, BlockT, ConstT>::getCaseSuccessor()
1140+
const {
1141+
const auto &LLVMCaseHandle = *LLVMCaseIt;
1142+
auto *LLVMBB = LLVMCaseHandle.getCaseSuccessor();
1143+
return cast<BlockT>(Ctx.getValue(LLVMBB));
1144+
}
1145+
1146+
template class SwitchInst::CaseHandleImpl<llvm::SwitchInst::CaseIt, BasicBlock,
1147+
ConstantInt>;
1148+
template class SwitchInst::CaseItImpl<llvm::SwitchInst::CaseIt, BasicBlock,
1149+
ConstantInt>;
1150+
template class SwitchInst::CaseHandleImpl<llvm::SwitchInst::ConstCaseIt,
1151+
const BasicBlock, const ConstantInt>;
1152+
template class SwitchInst::CaseItImpl<llvm::SwitchInst::ConstCaseIt,
1153+
const BasicBlock, const ConstantInt>;
1154+
11281155
ConstantInt *SwitchInst::findCaseDest(BasicBlock *BB) {
11291156
auto *LLVMC = cast<llvm::SwitchInst>(Val)->findCaseDest(
11301157
cast<llvm::BasicBlock>(BB->Val));

0 commit comments

Comments
 (0)