Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/swift/ABI/TrailingObjects.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,15 @@ class swift_ptrauth_struct_derived(BaseTy) TrailingObjects
}
};

// Helper function to determine at build time if a type has TrailingObjects.
// This is useful for determining if trailingTypeCount and
// sizeWithTrailingTypeCount are available for code that reads TrailingObjects
// values in a generalized fashion.
template <typename T>
static constexpr bool typeHasTrailingObjects() {
return std::is_base_of_v<trailing_objects_internal::TrailingObjectsBase, T>;
}

} // end namespace ABI
} // end namespace swift

Expand Down
176 changes: 80 additions & 96 deletions include/swift/Remote/MetadataReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1493,33 +1493,33 @@ class MetadataReader {
break;
case ContextDescriptorKind::Extension:
success =
readFullContextDescriptor<TargetExtensionContextDescriptor<Runtime>>(
remoteAddress, ptr);
readFullTrailingObjects<TargetExtensionContextDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Anonymous:
success =
readFullContextDescriptor<TargetAnonymousContextDescriptor<Runtime>>(
remoteAddress, ptr);
readFullTrailingObjects<TargetAnonymousContextDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Class:
success = readFullContextDescriptor<TargetClassDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetClassDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Enum:
success = readFullContextDescriptor<TargetEnumDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetEnumDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Struct:
success = readFullContextDescriptor<TargetStructDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetStructDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::Protocol:
success = readFullContextDescriptor<TargetProtocolDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetProtocolDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
case ContextDescriptorKind::OpaqueType:
success = readFullContextDescriptor<TargetOpaqueTypeDescriptor<Runtime>>(
remoteAddress, ptr);
success = readFullTrailingObjects<TargetOpaqueTypeDescriptor<Runtime>>(
remoteAddress, ptr, sizeof(TargetContextDescriptor<Runtime>));
break;
default:
// We don't know about this kind of context.
Expand All @@ -1535,12 +1535,18 @@ class MetadataReader {
return ContextDescriptorRef(remoteAddress, descriptor);
}

template <typename DescriptorTy>
bool readFullContextDescriptor(RemoteAddress address,
MemoryReader::ReadBytesResult &ptr) {
/// Read all memory occupied by a value with TrailingObjects. This will
/// incrementally read pieces of the object to figure out the full size of it.
/// - address: The address of the value.
/// - ptr: The bytes that have been read so far. On return, the full object.
/// - existingByteCount: The number of bytes in ptr.
template <typename BaseTy>
bool readFullTrailingObjects(RemoteAddress address,
MemoryReader::ReadBytesResult &ptr,
size_t existingByteCount) {
// Read the full base descriptor if it's bigger than what we have so far.
if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor<Runtime>)) {
ptr = Reader->template readObj<DescriptorTy>(address);
if (sizeof(BaseTy) > existingByteCount) {
ptr = Reader->template readObj<BaseTy>(address);
if (!ptr)
return false;
}
Expand All @@ -1556,13 +1562,17 @@ class MetadataReader {
// size. Once we've walked through all the trailing objects, we've read
// everything.

size_t sizeSoFar = sizeof(DescriptorTy);
size_t sizeSoFar = sizeof(BaseTy);

for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) {
const DescriptorTy *descriptorSoFar =
reinterpret_cast<const DescriptorTy *>(ptr.get());
for (size_t i = 0; i < BaseTy::trailingTypeCount(); i++) {
const BaseTy *descriptorSoFar =
reinterpret_cast<const BaseTy *>(ptr.get());
size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i);
if (thisSize > sizeSoFar) {
// Make sure we haven't ended up with a ridiculous size.
if (thisSize > MaxMetadataSize)
return false;

ptr = Reader->readBytes(address, thisSize);
if (!ptr)
return false;
Expand Down Expand Up @@ -2141,45 +2151,22 @@ class MetadataReader {

switch (getEnumeratedMetadataKind(KindValue)) {
case MetadataKind::Class:

return _readMetadata<TargetClassMetadataType>(address);

return _readMetadataFixedSize<TargetClassMetadataType>(address);
case MetadataKind::Enum:
return _readMetadata<TargetEnumMetadata>(address);
return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::ErrorObject:
return _readMetadata<TargetEnumMetadata>(address);
case MetadataKind::Existential: {
RemoteAddress flagsAddress = address + sizeof(StoredPointer);

ExistentialTypeFlags::int_type flagsData;
if (!Reader->readInteger(flagsAddress, &flagsData))
return nullptr;

ExistentialTypeFlags flags(flagsData);

RemoteAddress numProtocolsAddress = flagsAddress + sizeof(flagsData);
uint32_t numProtocols;
if (!Reader->readInteger(numProtocolsAddress, &numProtocols))
return nullptr;

// Make sure the number of protocols is reasonable
if (numProtocols >= 256)
return nullptr;

auto totalSize = sizeof(TargetExistentialTypeMetadata<Runtime>)
+ numProtocols *
sizeof(ConstTargetMetadataPointer<Runtime, TargetProtocolDescriptor>);

if (flags.hasSuperclassConstraint())
totalSize += sizeof(StoredPointer);

return _readMetadata(address, totalSize);
}
return _readMetadataFixedSize<TargetMetadata>(address);
case MetadataKind::Existential:
return _readMetadataVariableSize<TargetExistentialTypeMetadata>(
address);
case MetadataKind::ExistentialMetatype:
return _readMetadata<TargetExistentialMetatypeMetadata>(address);
return _readMetadataFixedSize<TargetExistentialMetatypeMetadata>(
address);
case MetadataKind::ExtendedExistential: {
// We need to read the shape in order to figure out how large
// the generalization arguments are.
// the generalization arguments are. This prevents us from using
// _readMetadataVariableSize, which requires the Shape field to be
// dereferenceable here.
RemoteAddress shapeAddress = address + sizeof(StoredPointer);
RemoteAddress signedShapePtr;
if (!Reader->template readRemoteAddress<StoredPointer>(shapeAddress,
Expand All @@ -2198,46 +2185,24 @@ class MetadataReader {
return _readMetadata(address, totalSize);
}
case MetadataKind::ForeignClass:
return _readMetadata<TargetForeignClassMetadata>(address);
return _readMetadataFixedSize<TargetForeignClassMetadata>(address);
case MetadataKind::ForeignReferenceType:
return _readMetadata<TargetForeignReferenceTypeMetadata>(address);
case MetadataKind::Function: {
StoredSize flagsValue;
auto flagsAddr =
address + TargetFunctionTypeMetadata<Runtime>::OffsetToFlags;
if (!Reader->readInteger(flagsAddr, &flagsValue))
return nullptr;

auto flags =
TargetFunctionTypeFlags<StoredSize>::fromIntValue(flagsValue);

auto totalSize =
sizeof(TargetFunctionTypeMetadata<Runtime>) +
flags.getNumParameters() * sizeof(FunctionTypeMetadata::Parameter);

if (flags.hasParameterFlags())
totalSize += flags.getNumParameters() * sizeof(uint32_t);

if (flags.isDifferentiable())
totalSize = roundUpToAlignment(totalSize, sizeof(StoredPointer)) +
sizeof(TargetFunctionMetadataDifferentiabilityKind<
typename Runtime::StoredSize>);

return _readMetadata(address,
roundUpToAlignment(totalSize, sizeof(StoredPointer)));
}
return _readMetadataFixedSize<TargetForeignReferenceTypeMetadata>(
address);
case MetadataKind::Function:
return _readMetadataVariableSize<TargetFunctionTypeMetadata>(address);
case MetadataKind::HeapGenericLocalVariable:
return _readMetadata<TargetGenericBoxHeapMetadata>(address);
return _readMetadataFixedSize<TargetGenericBoxHeapMetadata>(address);
case MetadataKind::HeapLocalVariable:
return _readMetadata<TargetHeapLocalVariableMetadata>(address);
return _readMetadataFixedSize<TargetHeapLocalVariableMetadata>(address);
case MetadataKind::Metatype:
return _readMetadata<TargetMetatypeMetadata>(address);
return _readMetadataFixedSize<TargetMetatypeMetadata>(address);
case MetadataKind::ObjCClassWrapper:
return _readMetadata<TargetObjCClassWrapperMetadata>(address);
return _readMetadataFixedSize<TargetObjCClassWrapperMetadata>(address);
case MetadataKind::Optional:
return _readMetadata<TargetEnumMetadata>(address);
return _readMetadataFixedSize<TargetEnumMetadata>(address);
case MetadataKind::Struct:
return _readMetadata<TargetStructMetadata>(address);
return _readMetadataFixedSize<TargetStructMetadata>(address);
case MetadataKind::Tuple: {
auto numElementsAddress = address +
TargetTupleTypeMetadata<Runtime>::getOffsetToNumElements();
Expand All @@ -2255,7 +2220,7 @@ class MetadataReader {
}
case MetadataKind::Opaque:
default:
return _readMetadata<TargetOpaqueMetadata>(address);
return _readMetadataFixedSize<TargetOpaqueMetadata>(address);
}

// We can fall out here if the value wasn't actually a valid
Expand Down Expand Up @@ -2333,20 +2298,39 @@ class MetadataReader {

private:
template <template <class R> class M>
MetadataRef _readMetadata(RemoteAddress address) {
MetadataRef _readMetadataFixedSize(RemoteAddress address) {
static_assert(!ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must not have trailing objects. Use "
"_readMetadataVariableSize for types that have them.");

return _readMetadata(address, sizeof(M<Runtime>));
}

template <template <class R> class M>
MetadataRef _readMetadataVariableSize(RemoteAddress address) {
static_assert(ABI::typeHasTrailingObjects<M<Runtime>>(),
"Type must have trailing objects. Use _readMetadataFixedSize "
"for types that don't.");

MemoryReader::ReadBytesResult bytes;
auto readResult = readFullTrailingObjects<M<Runtime>>(address, bytes, 0);
if (!readResult)
return nullptr;
return _cacheMetadata(address, bytes);
}

MetadataRef _readMetadata(RemoteAddress address, size_t sizeAfter) {
if (sizeAfter > MaxMetadataSize)
return nullptr;
auto readResult = Reader->readBytes(address, sizeAfter);
if (!readResult)
return nullptr;
return _cacheMetadata(address, readResult);
}

MetadataRef _cacheMetadata(RemoteAddress address,
MemoryReader::ReadBytesResult &bytes) {
auto metadata =
reinterpret_cast<const TargetMetadata<Runtime> *>(readResult.get());
MetadataCache.insert(std::make_pair(address, std::move(readResult)));
reinterpret_cast<const TargetMetadata<Runtime> *>(bytes.get());
MetadataCache.insert(std::make_pair(address, std::move(bytes)));
return MetadataRef(address, metadata);
}

Expand Down
56 changes: 56 additions & 0 deletions validation-test/Reflection/function_types.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %empty-directory(%t)
// RUN: %target-build-swift -lswiftSwiftReflectionTest %s -o %t/function_types
// RUN: %target-codesign %t/function_types

// RUN: %target-run %target-swift-reflection-test %t/function_types | %FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-%target-ptrsize %add_num_extra_inhabitants

// REQUIRES: reflection_test_support
// REQUIRES: executable_test
// UNSUPPORTED: use_os_stdlib
// UNSUPPORTED: asan

import SwiftReflectionTest

struct S {
var f = { @MainActor in }
}

// This Task is necessary to ensure that the concurrency runtime is brought in.
// Without that, the type lookup for @MainActor may fail.
Task {}

// CHECK: Type reference:
// CHECK: (struct function_types.S)

// CHECK: Type info:
// CHECK: (struct size=
// CHECK: (field name=f offset=0
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context
// CHECK: (reference kind=strong refcounting=native)))))
// CHECK: Mangled name: $s14function_types1SV
// CHECK: Demangled name: function_types.S
reflect(any: S())

// CHECK: Type reference:
// CHECK: (function
// CHECK: (global-actor
// CHECK: (class Swift.MainActor))
// CHECK: (parameters)
// CHECK: (result
// CHECK: (tuple))

// CHECK: Type info:
// CHECK: (thick_function size=
// CHECK: (field name=function offset=0
// CHECK: (builtin size=
// CHECK: (field name=context offset=
// CHECK: (reference kind=strong refcounting=native)))
// CHECK: Mangled name: $syyScMYcc
// CHECK: Demangled name: @Swift.MainActor () -> ()
reflect(any: S().f)

doneReflecting()