Skip to content

Commit f6d9ef7

Browse files
committed
PackSpecialization: Fix result & type parameter handling
We can avoid issues with loweredType failing by not calling it. We should also only get the SILType for results that we actually intend to explode. Attempting to explode packs containing generic type parameters lead to compiler crashes. It is preferable to wait until generic types have been replaced with concrete types to explode packs, since we can only pass exploded pack elements directly if we know they are loadable. @Substituted types cause the shouldExplode predicate to be unreliable for AST types, so we have restricted it to just SIL types. We have added test cases for functions that have @Substituted types.
1 parent 1bb65d8 commit f6d9ef7

File tree

2 files changed

+220
-70
lines changed

2 files changed

+220
-70
lines changed

SwiftCompilerSources/Sources/Optimizer/FunctionPasses/PackSpecialization.swift

Lines changed: 106 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -353,36 +353,45 @@ private struct CallSiteSpecializer {
353353
var directResults = resultInstruction.results.makeIterator()
354354
var substitutedResultTupleElements = [any Value]()
355355
var mappedResultPacks = self.callee.resultMap.makeIterator()
356+
var indirectResultIterator = self.apply.indirectResultOperands.makeIterator()
356357

357358
for resultInfo in self.apply.functionConvention.results {
358359
// We only need to handle direct and pack results, since indirect results are handled above
359360
if !resultInfo.isSILIndirect {
360361
// Direct results of the original function are mapped to direct results of the specialized function.
361362
substitutedResultTupleElements.append(directResults.next()!)
362363

363-
} else if resultInfo.type.shouldExplode {
364-
// Some elements of pack results may get mapped to direct results of the specialized function.
365-
// We handle those here.
366-
let mapped = mappedResultPacks.next()!
367-
368-
let originalPackArgument = self.apply.arguments[mapped.argumentIndex]
369-
let packIndices = packArgumentIndices[mapped.argumentIndex]!
370-
371-
for (mappedDirectResult, (packIndex, elementType)) in zip(
372-
mapped.expandedElements, zip(packIndices, originalPackArgument.type.packElements)
373-
)
374-
where !mappedDirectResult.isSILIndirect
375-
{
376-
377-
let result = directResults.next()!
378-
let outputResultAddress = builder.createPackElementGet(
379-
packIndex: packIndex, pack: originalPackArgument,
380-
elementType: elementType)
381-
382-
builder.createStore(
383-
source: result, destination: outputResultAddress,
384-
// The callee is responsible for initializing return pack elements
385-
ownership: storeOwnership(for: result, normal: .initialize))
364+
} else {
365+
guard let indirectResult = indirectResultIterator.next()?.value,
366+
indirectResult.type.shouldExplode
367+
else {
368+
continue
369+
}
370+
371+
do {
372+
// Some elements of pack results may get mapped to direct results of the specialized function.
373+
// We handle those here.
374+
let mapped = mappedResultPacks.next()!
375+
376+
377+
let packIndices = packArgumentIndices[mapped.argumentIndex]!
378+
379+
for (mappedDirectResult, (packIndex, elementType)) in zip(
380+
mapped.expandedElements, zip(packIndices, indirectResult.type.packElements)
381+
)
382+
where !mappedDirectResult.isSILIndirect
383+
{
384+
385+
let result = directResults.next()!
386+
let outputResultAddress = builder.createPackElementGet(
387+
packIndex: packIndex, pack: indirectResult,
388+
elementType: elementType)
389+
390+
builder.createStore(
391+
source: result, destination: outputResultAddress,
392+
// The callee is responsible for initializing return pack elements
393+
ownership: storeOwnership(for: result, normal: .initialize))
394+
}
386395
}
387396
}
388397
}
@@ -499,6 +508,14 @@ private struct PackExplodedFunction {
499508
/// Index of this pack in the function's result type tuple.
500509
/// For a non-tuple result, this is 0.
501510
let resultIndex: Int
511+
/// ResultInfo for the results produced by exploding the original result.
512+
///
513+
/// NOTE: The expandedElements members of MappedResult & MappedParameter
514+
/// correspond to slices of the [ResultInfo] and [ParameterInfo] arrays
515+
/// produced at the same time as the ResultMap & ParameterMap respectively.
516+
/// Replacing these members with integer ranges or spans referring to those
517+
/// full arrays could be an easy performance optimization if this pass
518+
/// becomes a bottleneck.
502519
let expandedElements: [ResultInfo]
503520
}
504521

@@ -510,6 +527,7 @@ private struct PackExplodedFunction {
510527
/// order.
511528
struct MappedParameter {
512529
let argumentIndex: Int
530+
/// ParameterInfo for the parameters produced by exploding the original parameter.
513531
let expandedElements: [ParameterInfo]
514532
}
515533

@@ -631,36 +649,42 @@ private struct PackExplodedFunction {
631649
var resultMap = ResultMap()
632650
var newResults = [ResultInfo]()
633651

634-
var indirectResultIdx = 0
635-
for (resultIndex, resultInfo) in function.convention.results.enumerated() {
636-
let silType = resultInfo.type.loweredType(in: function)
637-
if silType.shouldExplode {
638-
let mappedResultInfos = silType.packElements.map { elem in
639-
// TODO: Determine correct values for options and hasLoweredAddress
640-
ResultInfo(
641-
type: elem.canonicalType,
642-
convention: explodedPackElementResultConvention(in: function, type: elem),
643-
options: resultInfo.options,
644-
hasLoweredAddresses: resultInfo.hasLoweredAddresses)
645-
}
652+
var indirectResultIterator = function.arguments[0..<function.convention.indirectSILResultCount]
653+
.lazy.enumerated().makeIterator()
646654

647-
resultMap.append(
648-
MappedResult(
649-
argumentIndex: indirectResultIdx, resultIndex: resultIndex,
650-
expandedElements: mappedResultInfos))
651-
newResults += mappedResultInfos
652-
} else {
653-
// Leave the original result unchanged
655+
for (resultIndex, resultInfo) in function.convention.results.enumerated() {
656+
assert(
657+
!resultInfo.isSILIndirect || !indirectResultIterator.isEmpty,
658+
"There must be exactly as many indirect results in the function convention and argument list."
659+
)
660+
661+
guard resultInfo.isSILIndirect,
662+
// There should always be a value here (expressed by the assert above).
663+
let (indirectResultIdx, indirectResult) = indirectResultIterator.next(),
664+
indirectResult.type.shouldExplode
665+
else {
654666
newResults.append(resultInfo)
667+
continue
655668
}
656669

657-
if resultInfo.isSILIndirect {
658-
indirectResultIdx += 1
670+
let mappedResultInfos = indirectResult.type.packElements.map { elem in
671+
ResultInfo(
672+
type: elem.canonicalType,
673+
convention: explodedPackElementResultConvention(in: function, type: elem),
674+
options: resultInfo.options,
675+
hasLoweredAddresses: resultInfo.hasLoweredAddresses)
659676
}
677+
678+
resultMap.append(
679+
MappedResult(
680+
argumentIndex: indirectResultIdx, resultIndex: resultIndex,
681+
expandedElements: mappedResultInfos))
682+
newResults += mappedResultInfos
683+
660684
}
661685

662686
assert(
663-
indirectResultIdx == function.argumentConventions.firstParameterIndex,
687+
indirectResultIterator.isEmpty,
664688
"We should have walked through all the indirect results, and no further.")
665689

666690
return (newResults, resultMap)
@@ -726,31 +750,36 @@ private struct PackExplodedFunction {
726750
let originalValue = originalReturn.returnedValue
727751

728752
let originalReturnTupleElements: [Value]
729-
if originalValue.type.isTuple {
753+
if originalValue.type.isVoid {
754+
originalReturnTupleElements = []
755+
} else if originalValue.type.isTuple {
730756
originalReturnTupleElements = [Value](
731757
builder.createDestructureTuple(tuple: originalValue).results)
732758
} else {
733759
originalReturnTupleElements = [originalValue]
734760
}
735761

736-
var returnValues = [any Value]()
737-
738762
// Thread together the original and exploded direct return values.
739-
var resultMapIndex = 0
740-
var originalReturnIndex = 0
741-
for (i, originalResult) in self.original.convention.results.enumerated()
742-
where originalResult.type.shouldExplode
743-
|| !originalResult.isSILIndirect
744-
{
745-
if !resultMap.indices.contains(resultMapIndex) || resultMap[resultMapIndex].resultIndex != i {
746-
returnValues.append(originalReturnTupleElements[originalReturnIndex])
747-
originalReturnIndex += 1
748-
749-
} else {
763+
let theReturnValues: [any Value]
764+
do {
765+
var returnValues = [any Value]()
766+
// The next original result to process.
767+
var resultIndex = 0
768+
var originalDirectResultIterator = originalReturnTupleElements.makeIterator()
769+
770+
for mappedResult in resultMap {
771+
772+
// Collect any direct results before the next mappedResult.
773+
while resultIndex < mappedResult.resultIndex {
774+
if !self.original.convention.results[resultIndex].isSILIndirect {
775+
returnValues.append(originalDirectResultIterator.next()!)
776+
}
777+
resultIndex += 1
778+
}
750779

751-
let mapped = resultMap[resultMapIndex]
780+
assert(resultIndex == mappedResult.resultIndex, "The next pack result is not skipped.")
752781

753-
let argumentMapping = argumentMap[mapped.argumentIndex]!
782+
let argumentMapping = argumentMap[mappedResult.argumentIndex]!
754783
for argument in argumentMapping.arguments {
755784

756785
switch argument.resources {
@@ -769,18 +798,26 @@ private struct PackExplodedFunction {
769798
}
770799
}
771800

772-
resultMapIndex += 1
801+
// We have finished processing mappedResult, so step forward.
802+
resultIndex += 1
773803
}
804+
805+
// Collect any remaining original direct results.
806+
while let directResult = originalDirectResultIterator.next() {
807+
returnValues.append(directResult)
808+
}
809+
810+
theReturnValues = returnValues
774811
}
775812

776-
// Return the single value directly, rather than constructing a single-element tuple for it.
777-
if returnValues.count == 1 {
778-
builder.createReturn(of: returnValues[0])
813+
// Return a single return value directly, rather than constructing a single-element tuple for it.
814+
if theReturnValues.count == 1 {
815+
builder.createReturn(of: theReturnValues[0])
779816
} else {
780-
let tupleElementTypes = returnValues.map { $0.type }
817+
let tupleElementTypes = theReturnValues.map { $0.type }
781818
let tupleType = context.getTupleType(elements: tupleElementTypes).loweredType(
782819
in: specialized)
783-
let tuple = builder.createTuple(type: tupleType, elements: returnValues)
820+
let tuple = builder.createTuple(type: tupleType, elements: theReturnValues)
784821
builder.createReturn(of: tuple)
785822
}
786823

@@ -1026,10 +1063,10 @@ private func loadOwnership(for value: any Value, normal: LoadInst.LoadOwnership)
10261063
}
10271064
}
10281065

1029-
extension TypeProperties {
1066+
extension Type {
10301067
/// A pack argument can explode if it contains no pack expansion types
10311068
fileprivate var shouldExplode: Bool {
10321069
// For now, we only attempt to explode indirect packs, since these are the most common and inefficient.
1033-
return isSILPack && !containsSILPackExpansionType && isSILPackElementAddress
1070+
return isSILPack && !hasTypeParameter && !containsSILPackExpansionType && isSILPackElementAddress
10341071
}
10351072
}

test/SILOptimizer/pack_specialization.sil

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,27 @@ bb0(%0 : $*Pack{any P}, %1 : $*Pack{any P}):
377377
return %17
378378
}
379379

380+
// Handling of generic type parameters in packs.
381+
// We should wait until these are fully specialized before exploding.
382+
sil [ossa] @copy_pack_generic : $@convention(thin) <T> (@pack_guaranteed Pack{T}) -> @pack_out Pack{T} {
383+
bb0(%0 : $*Pack{T}, %1 : $*Pack{T}):
384+
%2 = tuple ()
385+
return %2
386+
}
387+
388+
// CHECK-LABEL: sil [ossa] @call_copy_pack_generic : $@convention(thin) <T> (@pack_guaranteed Pack{T}) -> @pack_out Pack{T} {
389+
// CHECK: bb0(%0 : $*Pack{T}, %1 : $*Pack{T})
390+
// CHECK: [[CPG:%[0-9]+]] = function_ref @copy_pack_generic
391+
// CHECK-NEXT: [[RESULT:%[0-9]+]] = apply [[CPG]]
392+
// CHECK-NEXT: return [[RESULT]]
393+
// CHECK-LABEL: } // end sil function 'call_copy_pack_generic'
394+
sil [ossa] @call_copy_pack_generic : $@convention(thin) <T> (@pack_guaranteed Pack{T}) -> @pack_out Pack{T} {
395+
bb0(%0 : $*Pack{T}, %1 : $*Pack{T}):
396+
%16 = function_ref @copy_pack_generic : $@convention(thin) <T> (@pack_guaranteed Pack{T}) -> @pack_out Pack{T}
397+
%17 = apply %16<T>(%0, %1) : $@convention(thin) <T> (@pack_guaranteed Pack{T}) -> @pack_out Pack{T}
398+
return %17
399+
}
400+
380401
// INTERLEAVING PACK AND NON-PACK ARGUMENTS TESTS:
381402
//
382403
// The new function arguments and return values corresponding to pack parameters
@@ -699,7 +720,8 @@ bb0(%0 : $*Pack{Int}):
699720
// BAIL OUT CONDITION TESTS:
700721
//
701722
// Only perform pack specialization on functions that have indirect pack
702-
// parameters that contain no pack expansions.
723+
// parameters and/or results that contain no pack expansions or generic type
724+
// parameters.
703725
//
704726
// 2025-10-15: We currently only explode packs with address elements
705727
// (SILPackType::isElementAddress), because these are the most common (since
@@ -789,3 +811,94 @@ bb0(%0 : $*Pack{Int}):
789811
%2 = apply %1(%0) : $@convention(thin) (@pack_guaranteed Pack{Int}) -> ()
790812
return %2
791813
}
814+
815+
816+
sil [ossa] @indirect_result_pack : $@convention(thin) () -> @pack_out Pack{Int} {
817+
bb0(%0 : $*Pack{Int}):
818+
%1 = tuple ()
819+
return %1
820+
}
821+
822+
sil [ossa] @direct_result_pack : $@convention(thin) () -> @pack_out @direct Pack{Int} {
823+
bb0(%0 : $*@direct Pack{Int}):
824+
%1 = tuple ()
825+
return %1
826+
}
827+
828+
sil [ossa] @result_pack_expansion : $@convention(thin) <each A> () -> @pack_out Pack{repeat each A} {
829+
bb0(%0 : $*Pack{repeat each A}):
830+
%1 = tuple ()
831+
return %1
832+
}
833+
834+
sil [ossa] @mixed_result_packs : $@convention(thin) <each A> () -> (@pack_out Pack{Int}, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each A}) {
835+
bb0(%0 : $*Pack{Int}, %1 : $*@direct Pack{Int}, %2 : $*Pack{repeat each A}):
836+
%3 = tuple ()
837+
return %3
838+
}
839+
840+
// CHECK-LABEL: sil [ossa] @result_call_eligibility_tests : $@convention(thin) <each A> () -> (@pack_out Pack{Int}, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each A}) {
841+
// CHECK: bb0(%0 : $*Pack{Int}, %1 : $*@direct Pack{Int}, %2 : $*Pack{repeat each A}):
842+
// CHECK: function_ref @$s20indirect_result_packTf8x_n : $@convention(thin) () -> Int
843+
// CHECK: function_ref @direct_result_pack : $@convention(thin) () -> @pack_out @direct Pack{Int}
844+
// CHECK: function_ref @result_pack_expansion : $@convention(thin) <each τ_0_0> () -> @pack_out Pack{repeat each τ_0_0}
845+
// CHECK: function_ref @$s18mixed_result_packsTf8xnn_n : $@convention(thin) <each τ_0_0> () -> (Int, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each τ_0_0})
846+
// CHECK-LABEL: } // end sil function 'result_call_eligibility_tests'
847+
sil [ossa] @result_call_eligibility_tests : $@convention(thin) <each A> () -> (@pack_out Pack{Int}, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each A}) {
848+
bb0(%1 : $*Pack{Int}, %2 : $*@direct Pack{Int}, %3 : $*Pack{repeat each A}):
849+
850+
%4 = function_ref @indirect_result_pack : $@convention(thin) () -> @pack_out Pack{Int}
851+
%5 = apply %4(%1) : $@convention(thin) () -> @pack_out Pack{Int}
852+
853+
%8 = function_ref @direct_result_pack : $@convention(thin) () -> @pack_out @direct Pack{Int}
854+
%9 = apply %8(%2) : $@convention(thin) () -> @pack_out @direct Pack{Int}
855+
856+
%10 = function_ref @result_pack_expansion : $@convention(thin) <each A> () -> @pack_out Pack{repeat each A}
857+
%11 = apply %10<Pack{repeat each A}>(%3) : $@convention(thin) <each A> () -> @pack_out Pack{repeat each A}
858+
859+
%12 = function_ref @mixed_result_packs : $@convention(thin) <each A> () -> (@pack_out Pack{Int}, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each A})
860+
%13 = apply %12<Pack{repeat each A}>(%1, %2, %3) : $@convention(thin) <each A> () -> (@pack_out Pack{Int}, @pack_out @direct Pack{Int}, @pack_out Pack{repeat each A})
861+
862+
%99 = tuple ()
863+
return %99
864+
}
865+
866+
// @substituted TYPE TESTS:
867+
//
868+
// PackSubstitution must correctly identify the types of Pack arguments in the presence of @substituted types.
869+
870+
// CHECK-LABEL: sil shared [ossa] @$s25substitute_parameter_typeTf8x_n : $@convention(thin) @substituted <τ_0_0> (Int32) -> () for <Int32> {
871+
// CHECK: bb0(%0 : $Int32):
872+
// CHECK-LABEL: } // end sil function '$s25substitute_parameter_typeTf8x_n'
873+
sil [ossa] @substitute_parameter_type : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32> {
874+
bb0(%0 : $*Pack{Int32}):
875+
%99 = tuple ()
876+
return %99
877+
}
878+
879+
// CHECK-LABEL: sil [ossa] @call_substitute_parameter_type : $@convention(thin) (@pack_guaranteed Pack{Int32}) -> () {
880+
// CHECK-LABEL: } // end sil function 'call_substitute_parameter_type'
881+
sil [ossa] @call_substitute_parameter_type : $@convention(thin) (@pack_guaranteed Pack{Int32}) -> () {
882+
bb0(%0 : $*Pack{Int32}):
883+
%1 = function_ref @substitute_parameter_type : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32>
884+
%2 = apply %1(%0) : $@convention(thin) @substituted <T> (@pack_guaranteed Pack{T}) -> () for <Int32>
885+
return %2
886+
}
887+
888+
// CHECK-LABEL: sil shared [ossa] @$s22substitute_result_typeTf8x_n : $@convention(thin) @substituted <τ_0_0> () -> Int32 for <Int32> {
889+
// CHECK: bb0:
890+
// CHECK-LABEL: } // end sil function '$s22substitute_result_typeTf8x_n'
891+
sil [ossa] @substitute_result_type : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32> {
892+
bb0(%0 : $*Pack{Int32}):
893+
%99 = tuple ()
894+
return %99
895+
}
896+
897+
// CHECK-LABEL: sil [ossa] @call_substitute_result_type : $@convention(thin) () -> @pack_out Pack{Int32} {
898+
// CHECK-LABEL: } // end sil function 'call_substitute_result_type'
899+
sil [ossa] @call_substitute_result_type : $@convention(thin) () -> @pack_out Pack{Int32} {
900+
bb0(%0 : $*Pack{Int32}):
901+
%1 = function_ref @substitute_result_type : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32>
902+
%2 = apply %1(%0) : $@convention(thin) @substituted <T> () -> @pack_out Pack{T} for <Int32>
903+
return %2
904+
}

0 commit comments

Comments
 (0)