diff --git a/include/swift/ABI/TrailingObjects.h b/include/swift/ABI/TrailingObjects.h index 4ed23b7de6180..e9b64230ce4a9 100644 --- a/include/swift/ABI/TrailingObjects.h +++ b/include/swift/ABI/TrailingObjects.h @@ -421,6 +421,15 @@ class swift_ptrauth_struct_derived(BaseTy) TrailingObjects } }; +// Helper function to determine at build time if a type has TrailingObjects. +// This is useful for determining if trailingTypeCount and +// sizeWithTrailingTypeCount are available for code that reads TrailingObjects +// values in a generalized fashion. +template +static constexpr bool typeHasTrailingObjects() { + return std::is_base_of_v; +} + } // end namespace ABI } // end namespace swift diff --git a/include/swift/Remote/MetadataReader.h b/include/swift/Remote/MetadataReader.h index 770024ecc4e59..8ea3580a5e765 100644 --- a/include/swift/Remote/MetadataReader.h +++ b/include/swift/Remote/MetadataReader.h @@ -1493,33 +1493,33 @@ class MetadataReader { break; case ContextDescriptorKind::Extension: success = - readFullContextDescriptor>( - remoteAddress, ptr); + readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Anonymous: success = - readFullContextDescriptor>( - remoteAddress, ptr); + readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Class: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Enum: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Struct: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::Protocol: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; case ContextDescriptorKind::OpaqueType: - success = readFullContextDescriptor>( - remoteAddress, ptr); + success = readFullTrailingObjects>( + remoteAddress, ptr, sizeof(TargetContextDescriptor)); break; default: // We don't know about this kind of context. @@ -1535,12 +1535,18 @@ class MetadataReader { return ContextDescriptorRef(remoteAddress, descriptor); } - template - bool readFullContextDescriptor(RemoteAddress address, - MemoryReader::ReadBytesResult &ptr) { + /// Read all memory occupied by a value with TrailingObjects. This will + /// incrementally read pieces of the object to figure out the full size of it. + /// - address: The address of the value. + /// - ptr: The bytes that have been read so far. On return, the full object. + /// - existingByteCount: The number of bytes in ptr. + template + bool readFullTrailingObjects(RemoteAddress address, + MemoryReader::ReadBytesResult &ptr, + size_t existingByteCount) { // Read the full base descriptor if it's bigger than what we have so far. - if (sizeof(DescriptorTy) > sizeof(TargetContextDescriptor)) { - ptr = Reader->template readObj(address); + if (sizeof(BaseTy) > existingByteCount) { + ptr = Reader->template readObj(address); if (!ptr) return false; } @@ -1556,13 +1562,17 @@ class MetadataReader { // size. Once we've walked through all the trailing objects, we've read // everything. - size_t sizeSoFar = sizeof(DescriptorTy); + size_t sizeSoFar = sizeof(BaseTy); - for (size_t i = 0; i < DescriptorTy::trailingTypeCount(); i++) { - const DescriptorTy *descriptorSoFar = - reinterpret_cast(ptr.get()); + for (size_t i = 0; i < BaseTy::trailingTypeCount(); i++) { + const BaseTy *descriptorSoFar = + reinterpret_cast(ptr.get()); size_t thisSize = descriptorSoFar->sizeWithTrailingTypeCount(i); if (thisSize > sizeSoFar) { + // Make sure we haven't ended up with a ridiculous size. + if (thisSize > MaxMetadataSize) + return false; + ptr = Reader->readBytes(address, thisSize); if (!ptr) return false; @@ -2141,45 +2151,22 @@ class MetadataReader { switch (getEnumeratedMetadataKind(KindValue)) { case MetadataKind::Class: - - return _readMetadata(address); - + return _readMetadataFixedSize(address); case MetadataKind::Enum: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ErrorObject: - return _readMetadata(address); - case MetadataKind::Existential: { - RemoteAddress flagsAddress = address + sizeof(StoredPointer); - - ExistentialTypeFlags::int_type flagsData; - if (!Reader->readInteger(flagsAddress, &flagsData)) - return nullptr; - - ExistentialTypeFlags flags(flagsData); - - RemoteAddress numProtocolsAddress = flagsAddress + sizeof(flagsData); - uint32_t numProtocols; - if (!Reader->readInteger(numProtocolsAddress, &numProtocols)) - return nullptr; - - // Make sure the number of protocols is reasonable - if (numProtocols >= 256) - return nullptr; - - auto totalSize = sizeof(TargetExistentialTypeMetadata) - + numProtocols * - sizeof(ConstTargetMetadataPointer); - - if (flags.hasSuperclassConstraint()) - totalSize += sizeof(StoredPointer); - - return _readMetadata(address, totalSize); - } + return _readMetadataFixedSize(address); + case MetadataKind::Existential: + return _readMetadataVariableSize( + address); case MetadataKind::ExistentialMetatype: - return _readMetadata(address); + return _readMetadataFixedSize( + address); case MetadataKind::ExtendedExistential: { // We need to read the shape in order to figure out how large - // the generalization arguments are. + // the generalization arguments are. This prevents us from using + // _readMetadataVariableSize, which requires the Shape field to be + // dereferenceable here. RemoteAddress shapeAddress = address + sizeof(StoredPointer); RemoteAddress signedShapePtr; if (!Reader->template readRemoteAddress(shapeAddress, @@ -2198,46 +2185,24 @@ class MetadataReader { return _readMetadata(address, totalSize); } case MetadataKind::ForeignClass: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ForeignReferenceType: - return _readMetadata(address); - case MetadataKind::Function: { - StoredSize flagsValue; - auto flagsAddr = - address + TargetFunctionTypeMetadata::OffsetToFlags; - if (!Reader->readInteger(flagsAddr, &flagsValue)) - return nullptr; - - auto flags = - TargetFunctionTypeFlags::fromIntValue(flagsValue); - - auto totalSize = - sizeof(TargetFunctionTypeMetadata) + - flags.getNumParameters() * sizeof(FunctionTypeMetadata::Parameter); - - if (flags.hasParameterFlags()) - totalSize += flags.getNumParameters() * sizeof(uint32_t); - - if (flags.isDifferentiable()) - totalSize = roundUpToAlignment(totalSize, sizeof(StoredPointer)) + - sizeof(TargetFunctionMetadataDifferentiabilityKind< - typename Runtime::StoredSize>); - - return _readMetadata(address, - roundUpToAlignment(totalSize, sizeof(StoredPointer))); - } + return _readMetadataFixedSize( + address); + case MetadataKind::Function: + return _readMetadataVariableSize(address); case MetadataKind::HeapGenericLocalVariable: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::HeapLocalVariable: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Metatype: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::ObjCClassWrapper: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Optional: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Struct: - return _readMetadata(address); + return _readMetadataFixedSize(address); case MetadataKind::Tuple: { auto numElementsAddress = address + TargetTupleTypeMetadata::getOffsetToNumElements(); @@ -2255,7 +2220,7 @@ class MetadataReader { } case MetadataKind::Opaque: default: - return _readMetadata(address); + return _readMetadataFixedSize(address); } // We can fall out here if the value wasn't actually a valid @@ -2333,20 +2298,39 @@ class MetadataReader { private: template