
Commit db9052f

malfet authored and pobin6 committed
[MPS] Allow >2**32 metal dispatches (pytorch#140862)
By passing length as `NSUInteger`, which should be a 64-bit value on all 64-bit systems according to https://developer.apple.com/documentation/objectivec/nsuinteger?language=objc

Pull Request resolved: pytorch#140862
Approved by: https://github.com/Skylion007
1 parent a4278e3 commit db9052f
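
To illustrate why the 32-bit parameter limited dispatch size, here is a minimal sketch in plain C++. The two functions and `kNumel` are hypothetical stand-ins, not PyTorch code, and `uint64_t` is assumed to match `NSUInteger` on 64-bit systems, as the commit message states.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the old signature: a 32-bit length parameter.
constexpr uint64_t threads_with_u32_length(uint32_t length) { return length; }

// Hypothetical stand-in for the new signature: a 64-bit length parameter
// (assumed here to have the same width as NSUInteger on 64-bit systems).
constexpr uint64_t threads_with_u64_length(uint64_t length) { return length; }

// A job with 2**32 + 5 elements.
constexpr uint64_t kNumel = (1ULL << 32) + 5;

// Converting to uint32_t keeps only the low 32 bits, so just 5 threads remain.
static_assert(threads_with_u32_length(static_cast<uint32_t>(kNumel)) == 5,
              "32-bit length wraps around");

// The 64-bit parameter carries the full element count.
static_assert(threads_with_u64_length(kNumel) == kNumel,
              "64-bit length is preserved");
```

With a `uint32_t` parameter the wrap-around happens silently at the call site, which is presumably the failure mode this commit removes.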

1 file changed, +3 -2 lines changed

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 3 additions & 2 deletions
@@ -392,8 +392,9 @@ static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Cont
 
 static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                      id<MTLComputePipelineState> cplState,
-                                     uint32_t length) {
-  const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
+                                     NSUInteger length) {
+  static_assert(sizeof(NSUInteger) == sizeof(uint64_t));
+  const auto maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
   auto size = MTLSizeMake(length, 1, 1);
   auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
   [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
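
The switch from `const uint32_t` to `const auto` matters for the `std::min` call: `std::min` deduces a single type from both arguments, so mixing a 32-bit and a 64-bit operand would not compile. The sketch below shows the clamping logic in plain C++ (C++14 or later). `Length` and `threadgroup_width` are hypothetical names, with `uint64_t` assumed to stand in for `NSUInteger`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumption: uint64_t stands in for NSUInteger on 64-bit systems.
using Length = uint64_t;

// Hypothetical helper mirroring the threadgroup-size computation in the diff:
// the threadgroup is never wider than the job or the pipeline limit.
constexpr Length threadgroup_width(Length maxThreadsPerGroup, Length length) {
  // Both operands share one type, so std::min deduces it without casts.
  return std::min(maxThreadsPerGroup, length);
}

// Small job: the threadgroup shrinks to the job length.
static_assert(threadgroup_width(1024, 10) == 10, "clamped to length");

// Job larger than 2**32: the pipeline limit applies.
static_assert(threadgroup_width(1024, (1ULL << 32) + 5) == 1024,
              "clamped to pipeline limit");
```

Keeping both operands at the same width avoids an explicit cast at the call site and the risk of narrowing `length` back to 32 bits.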
