Skip to content

Commit 125c351

Browse files
committed
PackSpecialization: Refactor with @Substituted support
@Substituted types cause the shouldExplode predicate to be unreliable for AST types, so we have restricted it to just SIL types.
1 parent 2d8046a commit 125c351

File tree

2 files changed

+145
-68
lines changed

2 files changed

+145
-68
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/PackSpecialization.swift

Lines changed: 105 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -353,36 +353,45 @@ private struct CallSiteSpecializer {
353353
var directResults = resultInstruction.results.makeIterator()
354354
var substitutedResultTupleElements = [any Value]()
355355
var mappedResultPacks = self.callee.resultMap.makeIterator()
356+
var indirectResultIterator = self.apply.indirectResultOperands.makeIterator()
356357

357358
for resultInfo in self.apply.functionConvention.results {
358359
// We only need to handle direct and pack results, since indirect results are handled above
359360
if !resultInfo.isSILIndirect {
360361
// Direct results of the original function are mapped to direct results of the specialized function.
361362
substitutedResultTupleElements.append(directResults.next()!)
362363

363-
} else if resultInfo.type.shouldExplode {
364-
// Some elements of pack results may get mapped to direct results of the specialized function.
365-
// We handle those here.
366-
let mapped = mappedResultPacks.next()!
367-
368-
let originalPackArgument = self.apply.arguments[mapped.argumentIndex]
369-
let packIndices = packArgumentIndices[mapped.argumentIndex]!
370-
371-
for (mappedDirectResult, (packIndex, elementType)) in zip(
372-
mapped.expandedElements, zip(packIndices, originalPackArgument.type.packElements)
373-
)
374-
where !mappedDirectResult.isSILIndirect
375-
{
376-
377-
let result = directResults.next()!
378-
let outputResultAddress = builder.createPackElementGet(
379-
packIndex: packIndex, pack: originalPackArgument,
380-
elementType: elementType)
381-
382-
builder.createStore(
383-
source: result, destination: outputResultAddress,
384-
// The callee is responsible for initializing return pack elements
385-
ownership: storeOwnership(for: result, normal: .initialize))
364+
} else {
365+
guard let indirectResult = indirectResultIterator.next()?.value,
366+
indirectResult.type.shouldExplode
367+
else {
368+
continue
369+
}
370+
371+
do {
372+
// Some elements of pack results may get mapped to direct results of the specialized function.
373+
// We handle those here.
374+
let mapped = mappedResultPacks.next()!
375+
376+
377+
let packIndices = packArgumentIndices[mapped.argumentIndex]!
378+
379+
for (mappedDirectResult, (packIndex, elementType)) in zip(
380+
mapped.expandedElements, zip(packIndices, indirectResult.type.packElements)
381+
)
382+
where !mappedDirectResult.isSILIndirect
383+
{
384+
385+
let result = directResults.next()!
386+
let outputResultAddress = builder.createPackElementGet(
387+
packIndex: packIndex, pack: indirectResult,
388+
elementType: elementType)
389+
390+
builder.createStore(
391+
source: result, destination: outputResultAddress,
392+
// The callee is responsible for initializing return pack elements
393+
ownership: storeOwnership(for: result, normal: .initialize))
394+
}
386395
}
387396
}
388397
}
@@ -499,6 +508,14 @@ private struct PackExplodedFunction {
499508
/// Index of this pack in the function's result type tuple.
500509
/// For a non-tuple result, this is 0.
501510
let resultIndex: Int
511+
/// ResultInfo for the results produced by exploding the original result.
512+
///
513+
/// NOTE: The expandedElements members of MappedResult & MappedParameter
514+
/// correspond to slices of the [ResultInfo] and [ParameterInfo] arrays
515+
/// produced at the same time as the ResultMap & ParameterMap respectively.
516+
/// Replacing these members with integer ranges or spans referring to those
517+
/// full arrays could be an easy performance optimization if this pass
518+
/// becomes a bottleneck.
502519
let expandedElements: [ResultInfo]
503520
}
504521

@@ -510,6 +527,7 @@ private struct PackExplodedFunction {
510527
/// order.
511528
struct MappedParameter {
512529
let argumentIndex: Int
530+
/// ParameterInfo for the parameters produced by exploding the original parameter.
513531
let expandedElements: [ParameterInfo]
514532
}
515533

@@ -631,36 +649,42 @@ private struct PackExplodedFunction {
631649
var resultMap = ResultMap()
632650
var newResults = [ResultInfo]()
633651

634-
var indirectResultIdx = 0
635-
for (resultIndex, resultInfo) in function.convention.results.enumerated() {
636-
if resultInfo.type.shouldExplode {
637-
let silType = function.arguments[indirectResultIdx].type
638-
639-
let mappedResultInfos = silType.packElements.map { elem in
640-
ResultInfo(
641-
type: elem.canonicalType,
642-
convention: explodedPackElementResultConvention(in: function, type: elem),
643-
options: resultInfo.options,
644-
hasLoweredAddresses: resultInfo.hasLoweredAddresses)
645-
}
652+
var indirectResultIterator = function.arguments[0..<function.convention.indirectSILResultCount]
653+
.lazy.enumerated().makeIterator()
646654

647-
resultMap.append(
648-
MappedResult(
649-
argumentIndex: indirectResultIdx, resultIndex: resultIndex,
650-
expandedElements: mappedResultInfos))
651-
newResults += mappedResultInfos
652-
} else {
653-
// Leave the original result unchanged
655+
for (resultIndex, resultInfo) in function.convention.results.enumerated() {
656+
assert(
657+
!resultInfo.isSILIndirect || !indirectResultIterator.isEmpty,
658+
"There must be exactly as many indirect results in the function convention and argument list."
659+
)
660+
661+
guard resultInfo.isSILIndirect,
662+
// There should always be a value here (expressed by the assert above).
663+
let (indirectResultIdx, indirectResult) = indirectResultIterator.next(),
664+
indirectResult.type.shouldExplode
665+
else {
654666
newResults.append(resultInfo)
667+
continue
655668
}
656669

657-
if resultInfo.isSILIndirect {
658-
indirectResultIdx += 1
670+
let mappedResultInfos = indirectResult.type.packElements.map { elem in
671+
ResultInfo(
672+
type: elem.canonicalType,
673+
convention: explodedPackElementResultConvention(in: function, type: elem),
674+
options: resultInfo.options,
675+
hasLoweredAddresses: resultInfo.hasLoweredAddresses)
659676
}
677+
678+
resultMap.append(
679+
MappedResult(
680+
argumentIndex: indirectResultIdx, resultIndex: resultIndex,
681+
expandedElements: mappedResultInfos))
682+
newResults += mappedResultInfos
683+
660684
}
661685

662686
assert(
663-
indirectResultIdx == function.argumentConventions.firstParameterIndex,
687+
indirectResultIterator.isEmpty,
664688
"We should have walked through all the indirect results, and no further.")
665689

666690
return (newResults, resultMap)
@@ -726,31 +750,36 @@ private struct PackExplodedFunction {
726750
let originalValue = originalReturn.returnedValue
727751

728752
let originalReturnTupleElements: [Value]
729-
if originalValue.type.isTuple {
753+
if originalValue.type.isVoid {
754+
originalReturnTupleElements = []
755+
} else if originalValue.type.isTuple {
730756
originalReturnTupleElements = [Value](
731757
builder.createDestructureTuple(tuple: originalValue).results)
732758
} else {
733759
originalReturnTupleElements = [originalValue]
734760
}
735761

736-
var returnValues = [any Value]()
737-
738762
// Thread together the original and exploded direct return values.
739-
var resultMapIndex = 0
740-
var originalReturnIndex = 0
741-
for (i, originalResult) in self.original.convention.results.enumerated()
742-
where originalResult.type.shouldExplode
743-
|| !originalResult.isSILIndirect
744-
{
745-
if !resultMap.indices.contains(resultMapIndex) || resultMap[resultMapIndex].resultIndex != i {
746-
returnValues.append(originalReturnTupleElements[originalReturnIndex])
747-
originalReturnIndex += 1
748-
749-
} else {
763+
let theReturnValues: [any Value]
764+
do {
765+
var returnValues = [any Value]()
766+
// The next original result to process.
767+
var resultIndex = 0
768+
var originalDirectResultIterator = originalReturnTupleElements.makeIterator()
769+
770+
for mappedResult in resultMap {
771+
772+
// Collect any direct results before the next mappedResult.
773+
while resultIndex < mappedResult.resultIndex {
774+
if !self.original.convention.results[resultIndex].isSILIndirect {
775+
returnValues.append(originalDirectResultIterator.next()!)
776+
}
777+
resultIndex += 1
778+
}
750779

751-
let mapped = resultMap[resultMapIndex]
780+
assert(resultIndex == mappedResult.resultIndex, "The next pack result is not skipped.")
752781

753-
let argumentMapping = argumentMap[mapped.argumentIndex]!
782+
let argumentMapping = argumentMap[mappedResult.argumentIndex]!
754783
for argument in argumentMapping.arguments {
755784

756785
switch argument.resources {
@@ -769,18 +798,26 @@ private struct PackExplodedFunction {
769798
}
770799
}
771800

772-
resultMapIndex += 1
801+
// We have finished processing mappedResult, so step forward.
802+
resultIndex += 1
773803
}
804+
805+
// Collect any remaining original direct results.
806+
while let directResult = originalDirectResultIterator.next() {
807+
returnValues.append(directResult)
808+
}
809+
810+
theReturnValues = returnValues
774811
}
775812

776-
// Return the single value directly, rather than constructing a single-element tuple for it.
777-
if returnValues.count == 1 {
778-
builder.createReturn(of: returnValues[0])
813+
// Return a single return value directly, rather than constructing a single-element tuple for it.
814+
if theReturnValues.count == 1 {
815+
builder.createReturn(of: theReturnValues[0])
779816
} else {
780-
let tupleElementTypes = returnValues.map { $0.type }
817+
let tupleElementTypes = theReturnValues.map { $0.type }
781818
let tupleType = context.getTupleType(elements: tupleElementTypes).loweredType(
782819
in: specialized)
783-
let tuple = builder.createTuple(type: tupleType, elements: returnValues)
820+
let tuple = builder.createTuple(type: tupleType, elements: theReturnValues)
784821
builder.createReturn(of: tuple)
785822
}
786823

@@ -1026,7 +1063,7 @@ private func loadOwnership(for value: any Value, normal: LoadInst.LoadOwnership)
10261063
}
10271064
}
10281065

1029-
extension TypeProperties {
1066+
extension Type {
10301067
/// A pack argument can explode if it contains no pack expansion types
10311068
fileprivate var shouldExplode: Bool {
10321069
// For now, we only attempt to explode indirect packs, since these are the most common and inefficient.

test/SILOptimizer/pack_specialization.sil

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,3 +862,43 @@ bb0(%1 : $*Pack{Int}, %2 : $*@direct Pack{Int}, %3 : $*Pack{repeat each A}):
862862
%99 = tuple ()
863863
return %99
864864
}
865+
866+
// @substituted TYPE TESTS:
867+
//
868+
// PackSubstitution must correctly identify the types of Pack arguments in the presence of @substituted types.
869+
870+
// CHECK-LABEL: sil shared [ossa] @$s25substitute_parameter_typeTf8x_n : $@convention(thin) @substituted <τ_0_0> (Int32) -> () for <Int32> {
871+
// CHECK: bb0(%0 : $Int32):
872+
// CHECK-LABEL: } // end sil function '$s25substitute_parameter_typeTf8x_n'
873+
sil [ossa] @substitute_parameter_type : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32> {
874+
bb0(%0 : $*Pack{Int32}):
875+
%99 = tuple ()
876+
return %99
877+
}
878+
879+
// CHECK-LABEL: sil [ossa] @call_substitute_parameter_type : $@convention(thin) (@pack_guaranteed Pack{Int32}) -> () {
880+
// CHECK-LABEL: } // end sil function 'call_substitute_parameter_type'
881+
sil [ossa] @call_substitute_parameter_type : $@convention(thin) (@pack_guaranteed Pack{Int32}) -> () {
882+
bb0(%0 : $*Pack{Int32}):
883+
%1 = function_ref @substitute_parameter_type : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32>
884+
%2 = apply %1(%0) : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32>
885+
return %2
886+
}
887+
888+
// CHECK-LABEL: sil shared [ossa] @$s22substitute_result_typeTf8x_n : $@convention(thin) @substituted <τ_0_0> () -> Int32 for <Int32> {
889+
// CHECK: bb0:
890+
// CHECK-LABEL: } // end sil function '$s22substitute_result_typeTf8x_n'
891+
sil [ossa] @substitute_result_type : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32> {
892+
bb0(%0 : $*Pack{Int32}):
893+
%99 = tuple ()
894+
return %99
895+
}
896+
897+
// CHECK-LABEL: sil [ossa] @call_substitute_result_type : $@convention(thin) () -> @pack_out Pack{Int32} {
898+
// CHECK-LABEL: } // end sil function 'call_substitute_result_type'
899+
sil [ossa] @call_substitute_result_type : $@convention(thin) () -> @pack_out Pack{Int32} {
900+
bb0(%0 : $*Pack{Int32}):
901+
%1 = function_ref @substitute_result_type : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32>
902+
%2 = apply %1(%0) : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32>
903+
return %2
904+
}

0 commit comments

Comments
 (0)