diff --git a/compiler/map.go b/compiler/map.go index 6aaf8c76c6..9e380209a3 100644 --- a/compiler/map.go +++ b/compiler/map.go @@ -3,42 +3,28 @@ package compiler // This file emits the correct map intrinsics for map operations. import ( + "fmt" "go/token" "go/types" - - "github.com/tinygo-org/tinygo/src/tinygo" "golang.org/x/tools/go/ssa" "tinygo.org/x/go-llvm" ) +const hashArrayUnrollLimit = 4 + // createMakeMap creates a new map object (runtime.hashmap) by allocating and // initializing an appropriately sized object. func (b *builder) createMakeMap(expr *ssa.MakeMap) (llvm.Value, error) { mapType := expr.Type().Underlying().(*types.Map) keyType := mapType.Key().Underlying() llvmValueType := b.getLLVMType(mapType.Elem().Underlying()) - var llvmKeyType llvm.Type - var alg uint64 - if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { - // String keys. - llvmKeyType = b.getLLVMType(keyType) - alg = uint64(tinygo.HashmapAlgorithmString) - } else if hashmapIsBinaryKey(keyType) { - // Trivially comparable keys. - llvmKeyType = b.getLLVMType(keyType) - alg = uint64(tinygo.HashmapAlgorithmBinary) - } else { - // All other keys. Implemented as map[interface{}]valueType for ease of - // implementation. - llvmKeyType = b.getLLVMRuntimeType("_interface") - alg = uint64(tinygo.HashmapAlgorithmInterface) - } + llvmKeyType := b.getLLVMType(keyType) + keySize := b.targetData.TypeAllocSize(llvmKeyType) valueSize := b.targetData.TypeAllocSize(llvmValueType) llvmKeySize := llvm.ConstInt(b.uintptrType, keySize, false) llvmValueSize := llvm.ConstInt(b.uintptrType, valueSize, false) sizeHint := llvm.ConstInt(b.uintptrType, 8, false) - algEnum := llvm.ConstInt(b.ctx.Int8Type(), alg, false) if expr.Reserve != nil { sizeHint = b.getValue(expr.Reserve, getPos(expr)) var err error @@ -47,10 +33,42 @@ func (b *builder) createMakeMap(expr *ssa.MakeMap) (llvm.Value, error) { return llvm.Value{}, err } } - hashmap := b.createRuntimeCall("hashmapMake", []llvm.Value{llvmKeySize, llvmValueSize, sizeHint, algEnum}, "") + + // Resolve hash and equal functions for this key type. For string and + // binary key types, reference the corresponding runtime functions + // directly. For composite types, generate type-specific functions. + var hashFn, equalFn llvm.Value + if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { + hashFn = b.getRuntimeFunctionValue("hashmapStringPtrHash", hashmapKeyHashSignature()) + equalFn = b.getRuntimeFunctionValue("hashmapStringEqual", hashmapKeyEqualSignature()) + } else if hashmapIsBinaryKey(keyType) { + hashFn = b.getRuntimeFunctionValue("hash32", hashmapKeyHashSignature()) + equalFn = b.getRuntimeFunctionValue("memequal", hashmapKeyEqualSignature()) + } else { + fn := b.getOrGenerateKeyHashFunc(keyType) + hashFn = b.createFuncValue(fn, llvm.ConstNull(b.dataPtrType), hashmapKeyHashSignature()) + fn = b.getOrGenerateKeyEqualFunc(keyType) + equalFn = b.createFuncValue(fn, llvm.ConstNull(b.dataPtrType), hashmapKeyEqualSignature()) + } + + hashmap := b.createRuntimeCall("hashmapMakeGeneric", []llvm.Value{ + llvmKeySize, llvmValueSize, sizeHint, + hashFn, equalFn, + }, "") return hashmap, nil } +// getRuntimeFunctionValue returns a TinyGo function value (with nil context) +// for the named runtime function. +func (b *builder) getRuntimeFunctionValue(name string, sig *types.Signature) llvm.Value { + member := b.program.ImportedPackage("runtime").Members[name] + if member == nil { + panic("unknown runtime function: " + name) + } + _, llvmFn := b.getFunction(member.(*ssa.Function)) + return b.createFuncValue(llvmFn, llvm.ConstNull(b.dataPtrType), sig) +} + // createMapLookup returns the value in a map. It calls a runtime function // depending on the map key type to load the map value and its comma-ok value. func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Value, commaOk bool, pos token.Pos) (llvm.Value, error) { @@ -72,32 +90,23 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val // Do the lookup. How it is done depends on the key type. var commaOkValue llvm.Value - origKeyType := keyType keyType = keyType.Underlying() if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string params := []llvm.Value{m, key, mapValueAlloca, mapValueSize} commaOkValue = b.createRuntimeCall("hashmapStringGet", params, "") - } else if hashmapIsBinaryKey(keyType) { - // key can be compared with runtime.memequal - // Store the key in an alloca, in the entry block to avoid dynamic stack - // growth. + } else { + // Key stored at actual type: either binary-comparable or with + // compiler-generated hash/equal. mapKeyAlloca, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, mapKeyAlloca) - b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca) - // Fetch the value from the hashmap. params := []llvm.Value{m, mapKeyAlloca, mapValueAlloca, mapValueSize} - commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "") - b.emitLifetimeEnd(mapKeyAlloca, mapKeySize) - } else { - // Not trivially comparable using memcmp. Make it an interface instead. - itfKey := key - if _, ok := keyType.(*types.Interface); !ok { - // Not already an interface, so convert it to an interface now. - itfKey = b.createMakeInterface(key, origKeyType, pos) + fnName := "hashmapBinaryGet" + if !hashmapIsBinaryKey(keyType) { + fnName = "hashmapGenericGet" } - params := []llvm.Value{m, itfKey, mapValueAlloca, mapValueSize} - commaOkValue = b.createRuntimeCall("hashmapInterfaceGet", params, "") + commaOkValue = b.createRuntimeCall(fnName, params, "") + b.emitLifetimeEnd(mapKeyAlloca, mapKeySize) } // Load the resulting value from the hashmap. The value is set to the zero @@ -120,29 +129,22 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value, pos token.Pos) { valueAlloca, valueSize := b.createTemporaryAlloca(value.Type(), "hashmap.value") b.CreateStore(value, valueAlloca) - origKeyType := keyType keyType = keyType.Underlying() if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string params := []llvm.Value{m, key, valueAlloca} b.createRuntimeCall("hashmapStringSet", params, "") - } else if hashmapIsBinaryKey(keyType) { - // key can be compared with runtime.memequal + } else { + // Key stored at actual type. keyAlloca, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) - b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) + fnName := "hashmapBinarySet" + if !hashmapIsBinaryKey(keyType) { + fnName = "hashmapGenericSet" + } params := []llvm.Value{m, keyAlloca, valueAlloca} - b.createRuntimeCall("hashmapBinarySet", params, "") + b.createRuntimeCall(fnName, params, "") b.emitLifetimeEnd(keyAlloca, keySize) - } else { - // Key is not trivially comparable, so compare it as an interface instead. - itfKey := key - if _, ok := keyType.(*types.Interface); !ok { - // Not already an interface, so convert it to an interface first. - itfKey = b.createMakeInterface(key, origKeyType, pos) - } - params := []llvm.Value{m, itfKey, valueAlloca} - b.createRuntimeCall("hashmapInterfaceSet", params, "") } b.emitLifetimeEnd(valueAlloca, valueSize) } @@ -150,32 +152,24 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value, // createMapDelete deletes a key from a map by calling the appropriate runtime // function. It is the implementation of the Go delete() builtin. func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos token.Pos) error { - origKeyType := keyType keyType = keyType.Underlying() if t, ok := keyType.(*types.Basic); ok && t.Info()&types.IsString != 0 { // key is a string params := []llvm.Value{m, key} b.createRuntimeCall("hashmapStringDelete", params, "") return nil - } else if hashmapIsBinaryKey(keyType) { + } else { + // Key stored at actual type. keyAlloca, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) - b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) + fnName := "hashmapBinaryDelete" + if !hashmapIsBinaryKey(keyType) { + fnName = "hashmapGenericDelete" + } params := []llvm.Value{m, keyAlloca} - b.createRuntimeCall("hashmapBinaryDelete", params, "") + b.createRuntimeCall(fnName, params, "") b.emitLifetimeEnd(keyAlloca, keySize) return nil - } else { - // Key is not trivially comparable, so compare it as an interface - // instead. - itfKey := key - if _, ok := keyType.(*types.Interface); !ok { - // Not already an interface, so convert it to an interface first. - itfKey = b.createMakeInterface(key, origKeyType, pos) - } - params := []llvm.Value{m, itfKey} - b.createRuntimeCall("hashmapInterfaceDelete", params, "") - return nil } } @@ -195,42 +189,15 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv llvmKeyType := b.getLLVMType(keyType) llvmValueType := b.getLLVMType(valueType) - // There is a special case in which keys are stored as an interface value - // instead of the value they normally are. This happens for non-trivially - // comparable types such as float32 or some structs. - isKeyStoredAsInterface := false - if t, ok := keyType.Underlying().(*types.Basic); ok && t.Info()&types.IsString != 0 { - // key is a string - } else if hashmapIsBinaryKey(keyType) { - // key can be compared with runtime.memequal - } else { - // The key is stored as an interface value, and may or may not be an - // interface type (for example, float32 keys are stored as an interface - // value). - if _, ok := keyType.Underlying().(*types.Interface); !ok { - isKeyStoredAsInterface = true - } - } - - // Determine the type of the key as stored in the map. - llvmStoredKeyType := llvmKeyType - if isKeyStoredAsInterface { - llvmStoredKeyType = b.getLLVMRuntimeType("_interface") - } + // All key types are now stored at their declared type (no interface wrapping). // Extract the key and value from the map. - mapKeyAlloca, mapKeySize := b.createTemporaryAlloca(llvmStoredKeyType, "range.key") + mapKeyAlloca, mapKeySize := b.createTemporaryAlloca(llvmKeyType, "range.key") mapValueAlloca, mapValueSize := b.createTemporaryAlloca(llvmValueType, "range.value") ok := b.createRuntimeCall("hashmapNext", []llvm.Value{llvmRangeVal, it, mapKeyAlloca, mapValueAlloca}, "range.next") - mapKey := b.CreateLoad(llvmStoredKeyType, mapKeyAlloca, "") + mapKey := b.CreateLoad(llvmKeyType, mapKeyAlloca, "") mapValue := b.CreateLoad(llvmValueType, mapValueAlloca, "") - if isKeyStoredAsInterface { - // The key is stored as an interface but it isn't of interface type. - // Extract the underlying value. - mapKey = b.extractValueFromInterface(mapKey, llvmKeyType) - } - // End the lifetimes of the allocas, because we're done with them. b.emitLifetimeEnd(mapKeyAlloca, mapKeySize) b.emitLifetimeEnd(mapValueAlloca, mapValueSize) @@ -250,20 +217,9 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv func hashmapIsBinaryKey(keyType types.Type) bool { switch keyType := keyType.Underlying().(type) { case *types.Basic: - // TODO: unsafe.Pointer is also a binary key, but to support that we - // need to fix an issue with interp first (see - // https://github.com/tinygo-org/tinygo/pull/4898). - return keyType.Info()&(types.IsBoolean|types.IsInteger) != 0 + return keyType.Info()&(types.IsBoolean|types.IsInteger) != 0 || keyType.Kind() == types.UnsafePointer case *types.Pointer: return true - case *types.Struct: - for i := 0; i < keyType.NumFields(); i++ { - fieldType := keyType.Field(i).Type().Underlying() - if !hashmapIsBinaryKey(fieldType) { - return false - } - } - return true case *types.Array: return hashmapIsBinaryKey(keyType.Elem()) default: @@ -271,68 +227,400 @@ func hashmapIsBinaryKey(keyType types.Type) bool { } } -func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error { - // We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there. - // To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the - // offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure - // we handle nested types. Next, we determine if there are any padding bytes before the next - // element and zero those as well. - - zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) +// hashmapKeyHashSignature returns the Go type signature for hashmap key hash +// functions: func(key unsafe.Pointer, size, seed uintptr) uint32 +func hashmapKeyHashSignature() *types.Signature { + return types.NewSignatureType(nil, nil, nil, + types.NewTuple( + types.NewVar(token.NoPos, nil, "key", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "size", types.Typ[types.Uintptr]), + types.NewVar(token.NoPos, nil, "seed", types.Typ[types.Uintptr]), + ), + types.NewTuple( + types.NewVar(token.NoPos, nil, "", types.Typ[types.Uint32]), + ), + false, + ) +} - switch llvmType.TypeKind() { - case llvm.IntegerTypeKind: - // no padding bytes - return nil - case llvm.PointerTypeKind: - // mo padding bytes - return nil - case llvm.ArrayTypeKind: - llvmArrayType := llvmType - llvmElemType := llvmType.ElementType() +// hashmapKeyEqualSignature returns the Go type signature for hashmap key equal +// functions: func(x, y unsafe.Pointer, n uintptr) bool +func hashmapKeyEqualSignature() *types.Signature { + return types.NewSignatureType(nil, nil, nil, + types.NewTuple( + types.NewVar(token.NoPos, nil, "x", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "y", types.Typ[types.UnsafePointer]), + types.NewVar(token.NoPos, nil, "n", types.Typ[types.Uintptr]), + ), + types.NewTuple( + types.NewVar(token.NoPos, nil, "", types.Typ[types.Bool]), + ), + false, + ) +} - for i := 0; i < llvmArrayType.ArrayLength(); i++ { - idx := llvm.ConstInt(b.uintptrType, uint64(i), false) - elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "") +// hashmapKeyFuncName returns a canonical name for a generated hash or equal +// function based on the key type's underlying structure. Named types are +// replaced with their underlying types so that structurally identical key +// types (e.g., struct{i1; str1} and struct{i2; str2} where both i1, i2 are +// int and str1, str2 are string) share the same generated function. +func hashmapKeyFuncName(prefix string, keyType types.Type) string { + return prefix + "." + hashmapCanonicalTypeName(keyType) +} - // zero any padding bytes in this element - b.zeroUndefBytes(llvmElemType, elemPtr) +// hashmapCanonicalTypeName returns a string representation of the hash/equal +// operations needed for a type, stripping named types where the operation does +// not depend on the name. Pointer and channel names do not include the element +// type because their hash/equal operations only use the pointer word. +func hashmapCanonicalTypeName(t types.Type) string { + switch t := t.Underlying().(type) { + case *types.Basic: + return t.Name() + case *types.Pointer: + return "*" + case *types.Chan: + switch t.Dir() { + case types.SendRecv: + return "chan" + case types.SendOnly: + return "chan<-" + case types.RecvOnly: + return "<-chan" + } + case *types.Interface: + if t.NumMethods() == 0 { + return "interface{}" + } + return t.String() + case *types.Struct: + s := "struct{" + for i := 0; i < t.NumFields(); i++ { + if i > 0 { + s += "; " + } + s += hashmapCanonicalTypeName(t.Field(i).Type()) } + return s + "}" + case *types.Array: + return fmt.Sprintf("[%d]%s", t.Len(), hashmapCanonicalTypeName(t.Elem())) + } + return t.String() +} - case llvm.StructTypeKind: - llvmStructType := llvmType - numFields := llvmStructType.StructElementTypesCount() - llvmElementTypes := llvmStructType.StructElementTypes() +// getOrGenerateKeyHashFunc returns an LLVM function that computes the hash +// of a key of the given type. The function is generated on first call and +// cached in the module. +func (b *builder) getOrGenerateKeyHashFunc(keyType types.Type) llvm.Value { + name := hashmapKeyFuncName("hashmapKeyHash", keyType) + if fn := b.mod.NamedFunction(name); !fn.IsNil() { + return fn + } - for i := 0; i < numFields; i++ { - idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false) - elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "") + // Create the LLVM function type: + // (key ptr, size uintptr, seed uintptr, context ptr) -> i32 + fnType := llvm.FunctionType(b.ctx.Int32Type(), []llvm.Type{ + b.dataPtrType, b.uintptrType, b.uintptrType, b.dataPtrType, + }, false) + fn := llvm.AddFunction(b.mod, name, fnType) + fn.SetLinkage(llvm.LinkOnceODRLinkage) + fn.SetUnnamedAddr(true) + b.addStandardAttributes(fn) + + // Generate the function body. + savedBlock := b.GetInsertBlock() + defer b.SetInsertPointAtEnd(savedBlock) + + entry := b.ctx.AddBasicBlock(fn, "entry") + b.SetInsertPointAtEnd(entry) + + keyPtr := fn.Param(0) + seed := fn.Param(2) + llvmKeyType := b.getLLVMType(keyType) + hash := b.generateKeyHash(keyType, llvmKeyType, keyPtr, seed) + b.CreateRet(hash) - // zero any padding bytes in this field - llvmElemType := llvmElementTypes[i] - b.zeroUndefBytes(llvmElemType, elemPtr) + return fn +} - // zero any padding bytes before the next field, if any - offset := b.targetData.ElementOffset(llvmStructType, i) - storeSize := b.targetData.TypeStoreSize(llvmElemType) - fieldEndOffset := offset + storeSize +// getOrGenerateKeyEqualFunc returns an LLVM function that compares two keys +// of the given type for equality. The function is generated on first call +// and cached in the module. +func (b *builder) getOrGenerateKeyEqualFunc(keyType types.Type) llvm.Value { + name := hashmapKeyFuncName("hashmapKeyEqual", keyType) + if fn := b.mod.NamedFunction(name); !fn.IsNil() { + return fn + } - var nextOffset uint64 - if i < numFields-1 { - nextOffset = b.targetData.ElementOffset(llvmStructType, i+1) - } else { - // Last field? Next offset is the total size of the allocate struct. - nextOffset = b.targetData.TypeAllocSize(llvmStructType) - } + // Create the LLVM function type: + // (x ptr, y ptr, n uintptr, context ptr) -> i1 + fnType := llvm.FunctionType(b.ctx.Int1Type(), []llvm.Type{ + b.dataPtrType, b.dataPtrType, b.uintptrType, b.dataPtrType, + }, false) + fn := llvm.AddFunction(b.mod, name, fnType) + fn.SetLinkage(llvm.LinkOnceODRLinkage) + fn.SetUnnamedAddr(true) + b.addStandardAttributes(fn) + + // Generate the function body. + savedBlock := b.GetInsertBlock() + defer b.SetInsertPointAtEnd(savedBlock) + + entry := b.ctx.AddBasicBlock(fn, "entry") + b.SetInsertPointAtEnd(entry) + + xPtr := fn.Param(0) + yPtr := fn.Param(1) + llvmKeyType := b.getLLVMType(keyType) + result := b.generateKeyEqual(keyType, llvmKeyType, xPtr, yPtr, fn) + b.CreateRet(result) - if fieldEndOffset != nextOffset { - n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false) - llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false) - paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), elemPtr, []llvm.Value{llvmStoreSize}, "") - b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "") + return fn +} + +// generateKeyHash generates IR that hashes a key value. Returns the i32 hash. +func (b *builder) generateKeyHash(keyType types.Type, llvmKeyType llvm.Type, keyPtr llvm.Value, seed llvm.Value) llvm.Value { + switch keyType := keyType.Underlying().(type) { + case *types.Basic: + if keyType.Info()&types.IsString != 0 { + // Hash the string contents. The size parameter is unused by + // hashmapStringPtrHash (it dereferences the string header to + // get the actual length), but we pass it for signature + // consistency with other hash functions. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hashmapStringPtrHash", []llvm.Value{keyPtr, size, seed}, "hash") + } + if keyType.Info()&types.IsFloat != 0 { + // Float hash: normalizes -0 to +0 before hashing. + if keyType.Kind() == types.Float32 { + return b.createRuntimeCall("hashmapFloat32Hash", []llvm.Value{keyPtr, seed}, "hash") + } + return b.createRuntimeCall("hashmapFloat64Hash", []llvm.Value{keyPtr, seed}, "hash") + } + if keyType.Info()&types.IsComplex != 0 { + // Complex hash: hash real and imaginary parts as floats. + if keyType.Kind() == types.Complex64 { + realPtr := keyPtr + imagPtr := b.CreateInBoundsGEP(b.ctx.Int8Type(), keyPtr, []llvm.Value{ + llvm.ConstInt(b.uintptrType, 4, false), + }, "") + realHash := b.createRuntimeCall("hashmapFloat32Hash", []llvm.Value{realPtr, seed}, "hash.real") + imagHash := b.createRuntimeCall("hashmapFloat32Hash", []llvm.Value{imagPtr, seed}, "hash.imag") + return b.CreateXor(realHash, imagHash, "") + } + realPtr := keyPtr + imagPtr := b.CreateInBoundsGEP(b.ctx.Int8Type(), keyPtr, []llvm.Value{ + llvm.ConstInt(b.uintptrType, 8, false), + }, "") + realHash := b.createRuntimeCall("hashmapFloat64Hash", []llvm.Value{realPtr, seed}, "hash.real") + imagHash := b.createRuntimeCall("hashmapFloat64Hash", []llvm.Value{imagPtr, seed}, "hash.imag") + return b.CreateXor(realHash, imagHash, "") + } + // Integer/boolean: hash the raw bytes. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hash32", []llvm.Value{keyPtr, size, seed}, "hash") + case *types.Pointer, *types.Chan: + // Pointers and channels: hash as raw pointer-sized bytes. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hash32", []llvm.Value{keyPtr, size, seed}, "hash") + case *types.Interface: + // Interface: use runtime reflection-based hash. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hashmapInterfacePtrHash", []llvm.Value{keyPtr, size, seed}, "hash") + case *types.Struct: + hash := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + for i := 0; i < keyType.NumFields(); i++ { + if keyType.Field(i).Name() == "_" { + continue // blank fields are ignored in Go equality + } + fieldType := keyType.Field(i).Type() + llvmFieldType := b.getLLVMType(fieldType) + if b.targetData.TypeAllocSize(llvmFieldType) == 0 { + continue // skip zero-sized fields + } + idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false) + fieldPtr := b.CreateInBoundsGEP(llvmKeyType, keyPtr, []llvm.Value{zero, idx}, "") + fieldHash := b.generateKeyHash(fieldType, llvmFieldType, fieldPtr, seed) + hash = b.CreateXor(hash, fieldHash, "") + } + return hash + case *types.Array: + elemType := keyType.Elem() + llvmElemType := b.getLLVMType(elemType) + arrayLen := keyType.Len() + if hashmapIsBinaryKey(elemType) { + // All elements are binary-comparable; hash the entire array as raw bytes. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hash32", []llvm.Value{keyPtr, size, seed}, "hash") + } + if arrayLen == 0 { + return llvm.ConstInt(b.ctx.Int32Type(), 0, false) + } + if arrayLen <= hashArrayUnrollLimit { + hash := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + for i := 0; i < int(arrayLen); i++ { + idx := llvm.ConstInt(b.uintptrType, uint64(i), false) + elemPtr := b.CreateInBoundsGEP(llvmKeyType, keyPtr, []llvm.Value{zero, idx}, "") + elemHash := b.generateKeyHash(elemType, llvmElemType, elemPtr, seed) + hash = b.CreateXor(hash, elemHash, "") } + return hash } + initHash := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + + loopEntry := b.GetInsertBlock() + loopBody := b.ctx.AddBasicBlock(loopEntry.Parent(), "hash.array.body") + loopDone := b.ctx.AddBasicBlock(loopEntry.Parent(), "hash.array.done") + + b.CreateBr(loopBody) + b.SetInsertPointAtEnd(loopBody) + + phiI := b.CreatePHI(b.uintptrType, "i") + phiHash := b.CreatePHI(b.ctx.Int32Type(), "hash.acc") + + idx := b.CreateTrunc(phiI, b.ctx.Int32Type(), "") + elemPtr := b.CreateInBoundsGEP(llvmKeyType, keyPtr, []llvm.Value{zero, idx}, "") + elemHash := b.generateKeyHash(elemType, llvmElemType, elemPtr, seed) + newHash := b.CreateXor(phiHash, elemHash, "") + nextI := b.CreateAdd(phiI, llvm.ConstInt(b.uintptrType, 1, false), "") + cond := b.CreateICmp(llvm.IntULT, nextI, llvm.ConstInt(b.uintptrType, uint64(arrayLen), false), "") + b.CreateCondBr(cond, loopBody, loopDone) + + phiI.AddIncoming([]llvm.Value{llvm.ConstInt(b.uintptrType, 0, false), nextI}, + []llvm.BasicBlock{loopEntry, loopBody}) + phiHash.AddIncoming([]llvm.Value{initHash, newHash}, + []llvm.BasicBlock{loopEntry, loopBody}) + + b.SetInsertPointAtEnd(loopDone) + return newHash + default: + panic(fmt.Sprintf("unhandled key type for hash generation: %T", keyType)) } +} - return nil +// generateKeyEqual generates IR that compares two key values for equality. +// Returns an i1 result. +func (b *builder) generateKeyEqual(keyType types.Type, llvmKeyType llvm.Type, xPtr, yPtr llvm.Value, fn llvm.Value) llvm.Value { + switch keyType := keyType.Underlying().(type) { + case *types.Basic: + if keyType.Info()&types.IsString != 0 { + // Compare strings: load both string headers and compare. + xStr := b.CreateLoad(llvmKeyType, xPtr, "x.str") + yStr := b.CreateLoad(llvmKeyType, yPtr, "y.str") + return b.createRuntimeCall("stringEqual", []llvm.Value{xStr, yStr}, "eq") + } + if keyType.Info()&types.IsFloat != 0 { + // Float equality: fcmp oeq handles -0==+0 (true) and NaN==NaN (false). + xVal := b.CreateLoad(llvmKeyType, xPtr, "x.float") + yVal := b.CreateLoad(llvmKeyType, yPtr, "y.float") + return b.CreateFCmp(llvm.FloatOEQ, xVal, yVal, "eq") + } + if keyType.Info()&types.IsComplex != 0 { + // Complex equality: both real and imaginary parts must be equal. + var floatType llvm.Type + if keyType.Kind() == types.Complex64 { + floatType = b.ctx.FloatType() + } else { + floatType = b.ctx.DoubleType() + } + floatSize := b.targetData.TypeAllocSize(floatType) + imagOffset := llvm.ConstInt(b.uintptrType, floatSize, false) + // Real parts + xReal := b.CreateLoad(floatType, xPtr, "x.real") + yReal := b.CreateLoad(floatType, yPtr, "y.real") + realEq := b.CreateFCmp(llvm.FloatOEQ, xReal, yReal, "eq.real") + // Imaginary parts + xImagPtr := b.CreateInBoundsGEP(b.ctx.Int8Type(), xPtr, []llvm.Value{imagOffset}, "") + yImagPtr := b.CreateInBoundsGEP(b.ctx.Int8Type(), yPtr, []llvm.Value{imagOffset}, "") + xImag := b.CreateLoad(floatType, xImagPtr, "x.imag") + yImag := b.CreateLoad(floatType, yImagPtr, "y.imag") + imagEq := b.CreateFCmp(llvm.FloatOEQ, xImag, yImag, "eq.imag") + return b.CreateAnd(realEq, imagEq, "") + } + // Integer/boolean: compare raw bytes. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("memequal", []llvm.Value{xPtr, yPtr, size}, "eq") + case *types.Pointer, *types.Chan: + // Pointers and channels: compare as raw pointer-sized bytes. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("memequal", []llvm.Value{xPtr, yPtr, size}, "eq") + case *types.Interface: + // Interface: use runtime interface equality. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("hashmapInterfaceEqual", []llvm.Value{xPtr, yPtr, size}, "eq") + case *types.Struct: + result := llvm.ConstInt(b.ctx.Int1Type(), 1, false) // start with true + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + for i := 0; i < keyType.NumFields(); i++ { + if keyType.Field(i).Name() == "_" { + continue // blank fields are ignored in Go equality + } + fieldType := keyType.Field(i).Type() + llvmFieldType := b.getLLVMType(fieldType) + if b.targetData.TypeAllocSize(llvmFieldType) == 0 { + continue // skip zero-sized fields + } + idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false) + xFieldPtr := b.CreateInBoundsGEP(llvmKeyType, xPtr, []llvm.Value{zero, idx}, "") + yFieldPtr := b.CreateInBoundsGEP(llvmKeyType, yPtr, []llvm.Value{zero, idx}, "") + fieldEq := b.generateKeyEqual(fieldType, llvmFieldType, xFieldPtr, yFieldPtr, fn) + result = b.CreateAnd(result, fieldEq, "") + } + return result + case *types.Array: + elemType := keyType.Elem() + llvmElemType := b.getLLVMType(elemType) + arrayLen := keyType.Len() + if hashmapIsBinaryKey(elemType) { + // All elements are binary-comparable; compare the entire array. + size := llvm.ConstInt(b.uintptrType, b.targetData.TypeAllocSize(llvmKeyType), false) + return b.createRuntimeCall("memequal", []llvm.Value{xPtr, yPtr, size}, "eq") + } + if arrayLen == 0 { + return llvm.ConstInt(b.ctx.Int1Type(), 1, false) + } + if arrayLen <= hashArrayUnrollLimit { + result := llvm.ConstInt(b.ctx.Int1Type(), 1, false) + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + for i := 0; i < int(arrayLen); i++ { + idx := llvm.ConstInt(b.uintptrType, uint64(i), false) + xElemPtr := b.CreateInBoundsGEP(llvmKeyType, xPtr, []llvm.Value{zero, idx}, "") + yElemPtr := b.CreateInBoundsGEP(llvmKeyType, yPtr, []llvm.Value{zero, idx}, "") + elemEq := b.generateKeyEqual(elemType, llvmElemType, xElemPtr, yElemPtr, fn) + result = b.CreateAnd(result, elemEq, "") + } + return result + } + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + + loopEntry := b.GetInsertBlock() + loopBody := b.ctx.AddBasicBlock(loopEntry.Parent(), "eq.array.body") + loopDone := b.ctx.AddBasicBlock(loopEntry.Parent(), "eq.array.done") + + b.CreateBr(loopBody) + b.SetInsertPointAtEnd(loopBody) + + phiI := b.CreatePHI(b.uintptrType, "i") + + idx := b.CreateTrunc(phiI, b.ctx.Int32Type(), "") + xElemPtr := b.CreateInBoundsGEP(llvmKeyType, xPtr, []llvm.Value{zero, idx}, "") + yElemPtr := b.CreateInBoundsGEP(llvmKeyType, yPtr, []llvm.Value{zero, idx}, "") + elemEq := b.generateKeyEqual(elemType, llvmElemType, xElemPtr, yElemPtr, fn) + + nextI := b.CreateAdd(phiI, llvm.ConstInt(b.uintptrType, 1, false), "") + atEnd := b.CreateICmp(llvm.IntUGE, nextI, llvm.ConstInt(b.uintptrType, uint64(arrayLen), false), "") + exitLoop := b.CreateOr(atEnd, b.CreateNot(elemEq, ""), "") + b.CreateCondBr(exitLoop, loopDone, loopBody) + + bodyEnd := b.GetInsertBlock() + phiI.AddIncoming([]llvm.Value{llvm.ConstInt(b.uintptrType, 0, false), nextI}, + []llvm.BasicBlock{loopEntry, bodyEnd}) + + b.SetInsertPointAtEnd(loopDone) + return elemEq + default: + panic(fmt.Sprintf("unhandled key type for equal generation: %T", keyType)) + } } diff --git a/compiler/symbol.go b/compiler/symbol.go index 4f24ddbfc3..16c5433cfa 100644 --- a/compiler/symbol.go +++ b/compiler/symbol.go @@ -190,6 +190,31 @@ func (c *compilerContext) getFunction(fn *ssa.Function) (llvm.Type, llvm.Value) case "runtime.stringFromRunes": llvmFn.AddAttributeAtIndex(1, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) llvmFn.AddAttributeAtIndex(1, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("readonly"), 0)) + case "runtime.hashmapSet": + // The key (param 2) and value (param 3) pointers are only read via + // memcpy/hash/equal and are never captured. The indirect calls + // through m.keyHash and m.keyEqual function pointers prevent LLVM's + // functionattrs pass from inferring this automatically. + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + llvmFn.AddAttributeAtIndex(3, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + case "runtime.hashmapGet": + // The key (param 2) is read-only and never captured. + // The value (param 3) is written to (receives the result) but never captured. + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + llvmFn.AddAttributeAtIndex(3, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + case "runtime.hashmapDelete": + // The key (param 2) is read-only and never captured. + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + case "runtime.hashmapGenericSet": + // Same as hashmapBinarySet: key (param 2) and value (param 3) are + // not captured. + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + llvmFn.AddAttributeAtIndex(3, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + case "runtime.hashmapGenericGet": + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + llvmFn.AddAttributeAtIndex(3, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) + case "runtime.hashmapGenericDelete": + llvmFn.AddAttributeAtIndex(2, c.ctx.CreateEnumAttribute(llvm.AttributeKindID("nocapture"), 0)) case "runtime.trackPointer": // This function is necessary for tracking pointers on the stack in a // portable way (see gc_stack_portable.go). Indicate to the optimizer diff --git a/compiler/testdata/go1.21.ll b/compiler/testdata/go1.21.ll index e8ab03dedb..57c63c90c0 100644 --- a/compiler/testdata/go1.21.ll +++ b/compiler/testdata/go1.21.ll @@ -169,13 +169,13 @@ entry: } ; Function Attrs: nounwind -define hidden void @main.clearMap(ptr dereferenceable_or_null(40) %m, ptr %context) unnamed_addr #2 { +define hidden void @main.clearMap(ptr dereferenceable_or_null(48) %m, ptr %context) unnamed_addr #2 { entry: call void @runtime.hashmapClear(ptr %m, ptr undef) #5 ret void } -declare void @runtime.hashmapClear(ptr dereferenceable_or_null(40), ptr) #1 +declare void @runtime.hashmapClear(ptr dereferenceable_or_null(48), ptr) #1 attributes #0 = { allockind("alloc,zeroed") allocsize(0) "alloc-family"="runtime.alloc" "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } attributes #1 = { "target-features"="+bulk-memory,+bulk-memory-opt,+call-indirect-overlong,+mutable-globals,+nontrapping-fptoint,+sign-ext,-multivalue,-reference-types" } diff --git a/compiler/testdata/zeromap.ll b/compiler/testdata/zeromap.ll index 058c14fb32..5ce2ebcb8a 100644 --- a/compiler/testdata/zeromap.ll +++ b/compiler/testdata/zeromap.ll @@ -17,7 +17,7 @@ entry: } ; Function Attrs: noinline nounwind -define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #3 { +define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(48) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #3 { entry: %hashmap.key = alloca %main.hasPadding, align 8 %hashmap.value = alloca i32, align 4 @@ -27,29 +27,23 @@ entry: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) store %main.hasPadding %2, ptr %hashmap.key, align 4 - %3 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 1 - call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #5 - %4 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 9 - call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #5 - %5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #5 + %3 = call i1 @runtime.hashmapGenericGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #5 call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) - %6 = load i32, ptr %hashmap.value, align 4 + %4 = load i32, ptr %hashmap.value, align 4 call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) - ret i32 %6 + ret i32 %4 } ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #4 -declare void @runtime.memzero(ptr, i32, ptr) #1 - -declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #1 +declare i1 @runtime.hashmapGenericGet(ptr dereferenceable_or_null(48), ptr nocapture, ptr nocapture, i32, ptr) #1 ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #4 ; Function Attrs: noinline nounwind -define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #3 { +define hidden void @main.testZeroSet(ptr dereferenceable_or_null(48) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #3 { entry: %hashmap.key = alloca %main.hasPadding, align 8 %hashmap.value = alloca i32, align 4 @@ -60,20 +54,16 @@ entry: store i32 5, ptr %hashmap.value, align 4 call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) store %main.hasPadding %2, ptr %hashmap.key, align 4 - %3 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 1 - call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #5 - %4 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 9 - call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #5 - call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #5 + call void @runtime.hashmapGenericSet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #5 call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) ret void } -declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #1 +declare void @runtime.hashmapGenericSet(ptr dereferenceable_or_null(48), ptr nocapture, ptr nocapture, ptr) #1 ; Function Attrs: noinline nounwind -define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #3 { +define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(48) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #3 { entry: %hashmap.key = alloca [2 x %main.hasPadding], align 8 %hashmap.value = alloca i32, align 4 @@ -84,23 +74,15 @@ entry: %hashmap.key.repack1 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 12 %s.elt2 = extractvalue [2 x %main.hasPadding] %s, 1 store %main.hasPadding %s.elt2, ptr %hashmap.key.repack1, align 4 - %0 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 1 - call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #5 - %1 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 9 - call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #5 - %2 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 13 - call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #5 - %3 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 21 - call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #5 - %4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #5 + %0 = call i1 @runtime.hashmapGenericGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #5 call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) - %5 = load i32, ptr %hashmap.value, align 4 + %1 = load i32, ptr %hashmap.value, align 4 call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) - ret i32 %5 + ret i32 %1 } ; Function Attrs: noinline nounwind -define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #3 { +define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(48) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #3 { entry: %hashmap.key = alloca [2 x %main.hasPadding], align 8 %hashmap.value = alloca i32, align 4 @@ -112,15 +94,7 @@ entry: %hashmap.key.repack1 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 12 %s.elt2 = extractvalue [2 x %main.hasPadding] %s, 1 store %main.hasPadding %s.elt2, ptr %hashmap.key.repack1, align 4 - %0 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 1 - call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #5 - %1 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 9 - call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #5 - %2 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 13 - call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #5 - %3 = getelementptr inbounds nuw i8, ptr %hashmap.key, i32 21 - call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #5 - call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #5 + call void @runtime.hashmapGenericSet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #5 call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) ret void diff --git a/interp/interp.go b/interp/interp.go index 30b0872485..88854137ac 100644 --- a/interp/interp.go +++ b/interp/interp.go @@ -35,6 +35,7 @@ type runner struct { start time.Time timeout time.Duration callsExecuted uint64 + interpErr error // set by Uint/Int when they encounter pointer data } func newRunner(mod llvm.Module, timeout time.Duration, debug bool) *runner { diff --git a/interp/interpreter.go b/interp/interpreter.go index e8f5545d5d..feb16a10e6 100644 --- a/interp/interpreter.go +++ b/interp/interpreter.go @@ -825,6 +825,14 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent } return nil, mem, r.errorAt(inst, errUnsupportedInst) } + + // Check if an instruction triggered a recoverable error (e.g., + // trying to interpret pointer data as integer bytes). + if r.interpErr != nil { + err := r.interpErr + r.interpErr = nil + return nil, mem, r.errorAt(inst, err) + } } return nil, mem, r.errorAt(bb.instructions[len(bb.instructions)-1], errors.New("interp: reached end of basic block without terminator")) } diff --git a/interp/memory.go b/interp/memory.go index 2812cd01c2..147bc5f2a0 100644 --- a/interp/memory.go +++ b/interp/memory.go @@ -556,11 +556,13 @@ func (v pointerValue) asRawValue(r *runner) rawValue { } func (v pointerValue) Uint(r *runner) uint64 { - panic("cannot convert pointer to integer") + r.interpErr = errUnsupportedInst + return 0 } func (v pointerValue) Int(r *runner) int64 { - panic("cannot convert pointer to integer") + r.interpErr = errUnsupportedInst + return 0 } func (v pointerValue) equal(rhs pointerValue) bool { @@ -736,11 +738,12 @@ func (v rawValue) asRawValue(r *runner) rawValue { return v } -func (v rawValue) bytes() []byte { +func (v rawValue) bytes(r *runner) []byte { buf := make([]byte, len(v.buf)) for i, p := range v.buf { if p > 255 { - panic("cannot convert pointer value to byte") + r.interpErr = errUnsupportedInst + return buf } buf[i] = byte(p) } @@ -748,7 +751,10 @@ func (v rawValue) bytes() []byte { } func (v rawValue) Uint(r *runner) uint64 { - buf := v.bytes() + buf := v.bytes(r) + if r.interpErr != nil { + return 0 + } switch len(v.buf) { case 1: diff --git a/main_test.go b/main_test.go index 55eb678910..66d18f4aad 100644 --- a/main_test.go +++ b/main_test.go @@ -77,6 +77,7 @@ func TestBuild(t *testing.T) { "interface.go", "json.go", "map.go", + "map_bigkey.go", "math.go", "oldgo/", "print.go", @@ -283,6 +284,11 @@ func runPlatTests(options compileopts.Options, tests []string, t *testing.T) { // limited amount of memory. continue + case "map_bigkey.go": + // Compiler generates many large stack temporaries for [256]byte + // map keys, overflowing the goroutine stack (384 bytes). + continue + case "gc.go": // Does not pass due to high mark false positive rate. continue diff --git a/src/internal/reflectlite/value.go b/src/internal/reflectlite/value.go index 3c2af94f72..115c640744 100644 --- a/src/internal/reflectlite/value.go +++ b/src/internal/reflectlite/value.go @@ -146,6 +146,19 @@ func TypeAssert[T any](v Value) (T, bool) { // valueInterfaceUnsafe is used by the runtime to hash map keys. It should not // be subject to the isExported check. +// loadSmallValue loads a value of size <= sizeof(uintptr) from ptr into +// a pointer-sized value suitable for storing in an interface's data field. +func loadSmallValue(ptr unsafe.Pointer, size uintptr) unsafe.Pointer { + if size == unsafe.Sizeof(uintptr(0)) { + return *(*unsafe.Pointer)(ptr) + } + var value uintptr + for j := size; j != 0; j-- { + value = (value << 8) | uintptr(*(*uint8)(unsafe.Add(ptr, j-1))) + } + return unsafe.Pointer(value) +} + func valueInterfaceUnsafe(v Value) interface{} { if v.typecode.Kind() == Interface { // The value itself is an interface. This can happen when getting the @@ -158,11 +171,7 @@ func valueInterfaceUnsafe(v Value) interface{} { if v.isIndirect() && v.typecode.Size() <= unsafe.Sizeof(uintptr(0)) { // Value was indirect but must be put back directly in the interface // value. - var value uintptr - for j := v.typecode.Size(); j != 0; j-- { - value = (value << 8) | uintptr(*(*uint8)(unsafe.Add(v.value, j-1))) - } - v.value = unsafe.Pointer(value) + v.value = loadSmallValue(v.value, v.typecode.Size()) } return composeInterface(unsafe.Pointer(v.typecode), v.value) } @@ -1085,18 +1094,8 @@ func (v Value) MapKeys() []Value { k := New(v.typecode.Key()) e := New(v.typecode.Elem()) - keyType := v.typecode.key() - keyTypeIsEmptyInterface := keyType.Kind() == Interface && keyType.NumMethod() == 0 - shouldUnpackInterface := !keyTypeIsEmptyInterface && keyType.Kind() != String && !keyType.isBinary() - for hashmapNext(v.pointer(), it, k.value, e.value) { - if shouldUnpackInterface { - intf := *(*interface{})(k.value) - v := ValueOf(intf) - keys = append(keys, v) - } else { - keys = append(keys, k.Elem()) - } + keys = append(keys, k.Elem()) k = New(v.typecode.Key()) } @@ -1109,9 +1108,40 @@ func hashmapStringGet(m unsafe.Pointer, key string, value unsafe.Pointer, valueS //go:linkname hashmapBinaryGet runtime.hashmapBinaryGet func hashmapBinaryGet(m unsafe.Pointer, key, value unsafe.Pointer, valueSize uintptr) bool -//go:linkname hashmapInterfaceGet runtime.hashmapInterfaceGet -func hashmapInterfaceGet(m unsafe.Pointer, key interface{}, value unsafe.Pointer, valueSize uintptr) bool - +//go:linkname hashmapGenericGet runtime.hashmapGenericGet +func hashmapGenericGet(m unsafe.Pointer, key, value unsafe.Pointer, valueSize uintptr) bool + +// genericKeyPtr returns a pointer to key data suitable for passing to the +// hashmapGeneric* functions. When the map's key type is an interface, +// special handling is needed: if the key Value already holds an interface +// (e.g. from MapKeys iteration), its memory already contains the +// {typecode, data} pair the hashmap expects, so we use it directly. +// If the key is a concrete type being assigned to an interface-keyed map, +// we compose the interface first. +func genericKeyPtr(vkey *RawType, key Value) unsafe.Pointer { + if vkey.Kind() == Interface { + if key.Kind() == Interface { + // Key is already an interface value stored indirectly; + // key.value points to {typecode, data}. + return key.value + } + // Concrete value being used as an interface key. + // For small addressable values, key.value is a pointer to + // the data, but the interface value field stores the data + // directly; load it using the same endian-safe approach as + // valueInterfaceUnsafe. + val := key.value + if key.isIndirect() && key.typecode.Size() <= unsafe.Sizeof(uintptr(0)) { + val = loadSmallValue(key.value, key.typecode.Size()) + } + intf := composeInterface(unsafe.Pointer(key.typecode), val) + return unsafe.Pointer(&intf) + } + if key.isIndirect() || key.typecode.Size() > unsafe.Sizeof(uintptr(0)) { + return key.value + } + return unsafe.Pointer(&key.value) +} func (v Value) MapIndex(key Value) Value { if v.Kind() != Map { panic(&ValueError{Method: "MapIndex", Kind: v.Kind()}) @@ -1140,13 +1170,16 @@ func (v Value) MapIndex(key Value) Value { } else { keyptr = unsafe.Pointer(&key.value) } - //TODO(dgryski): zero out padding bytes in key, if any if ok := hashmapBinaryGet(v.pointer(), keyptr, elem.value, elemType.Size()); !ok { return Value{} } return elem.Elem() } else { - if ok := hashmapInterfaceGet(v.pointer(), key.Interface(), elem.value, elemType.Size()); !ok { + // Compiler-generated hash/equal path: keys are stored at their + // actual type. Use hashmapGenericGet which dispatches through the + // map's own keyHash/keyEqual function pointers. + keyptr := genericKeyPtr(vkey, key) + if ok := hashmapGenericGet(v.pointer(), keyptr, elem.value, elemType.Size()); !ok { return Value{} } return elem.Elem() @@ -1171,8 +1204,7 @@ type MapIter struct { key Value val Value - valid bool - unpackKeyInterface bool + valid bool } func (it *MapIter) Key() Value { @@ -1180,12 +1212,6 @@ func (it *MapIter) Key() Value { panic("reflect.MapIter.Key called on invalid iterator") } - if it.unpackKeyInterface { - intf := *(*interface{})(it.key.value) - v := ValueOf(intf) - return v - } - return it.key.Elem() } @@ -1218,15 +1244,9 @@ func (iter *MapIter) Reset(v Value) { panic(&ValueError{Method: "MapRange", Kind: v.Kind()}) } - keyType := v.typecode.key() - - keyTypeIsEmptyInterface := keyType.Kind() == Interface && keyType.NumMethod() == 0 - shouldUnpackInterface := !keyTypeIsEmptyInterface && keyType.Kind() != String && !keyType.isBinary() - *iter = MapIter{ - m: v, - it: hashmapNewIterator(), - unpackKeyInterface: shouldUnpackInterface, + m: v, + it: hashmapNewIterator(), } } @@ -1969,8 +1989,8 @@ func hashmapStringSet(m unsafe.Pointer, key string, value unsafe.Pointer) //go:linkname hashmapBinarySet runtime.hashmapBinarySet func hashmapBinarySet(m unsafe.Pointer, key, value unsafe.Pointer) -//go:linkname hashmapInterfaceSet runtime.hashmapInterfaceSet -func hashmapInterfaceSet(m unsafe.Pointer, key interface{}, value unsafe.Pointer) +//go:linkname hashmapGenericSet runtime.hashmapGenericSet +func hashmapGenericSet(m unsafe.Pointer, key, value unsafe.Pointer) //go:linkname hashmapStringDelete runtime.hashmapStringDelete func hashmapStringDelete(m unsafe.Pointer, key string) @@ -1978,8 +1998,8 @@ func hashmapStringDelete(m unsafe.Pointer, key string) //go:linkname hashmapBinaryDelete runtime.hashmapBinaryDelete func hashmapBinaryDelete(m unsafe.Pointer, key unsafe.Pointer) -//go:linkname hashmapInterfaceDelete runtime.hashmapInterfaceDelete -func hashmapInterfaceDelete(m unsafe.Pointer, key interface{}) +//go:linkname hashmapGenericDelete runtime.hashmapGenericDelete +func hashmapGenericDelete(m unsafe.Pointer, key unsafe.Pointer) func (v Value) SetMapIndex(key, elem Value) { v.checkRO() @@ -2002,15 +2022,19 @@ func (v Value) SetMapIndex(key, elem Value) { } // make elem an interface if it needs to be converted - if v.typecode.elem().Kind() == Interface && elem.typecode.Kind() != Interface { - intf := composeInterface(unsafe.Pointer(elem.typecode), elem.value) + if !del && v.typecode.elem().Kind() == Interface && elem.typecode.Kind() != Interface { + val := elem.value + if elem.isIndirect() && elem.typecode.Size() <= unsafe.Sizeof(uintptr(0)) { + val = loadSmallValue(elem.value, elem.typecode.Size()) + } + intf := composeInterface(unsafe.Pointer(elem.typecode), val) elem = Value{ typecode: v.typecode.elem(), value: unsafe.Pointer(&intf), } } - if key.Kind() == String { + if vkey.Kind() == String { if del { hashmapStringDelete(v.pointer(), *(*string)(key.value)) } else { @@ -2023,7 +2047,7 @@ func (v Value) SetMapIndex(key, elem Value) { hashmapStringSet(v.pointer(), *(*string)(key.value), elemptr) } - } else if key.typecode.isBinary() { + } else if vkey.isBinary() { var keyptr unsafe.Pointer if key.isIndirect() || key.typecode.Size() > unsafe.Sizeof(uintptr(0)) { keyptr = key.value @@ -2043,8 +2067,11 @@ func (v Value) SetMapIndex(key, elem Value) { hashmapBinarySet(v.pointer(), keyptr, elemptr) } } else { + // Compiler-generated hash/equal path. + keyptr := genericKeyPtr(vkey, key) + if del { - hashmapInterfaceDelete(v.pointer(), key.Interface()) + hashmapGenericDelete(v.pointer(), keyptr) } else { var elemptr unsafe.Pointer if elem.isIndirect() || elem.typecode.Size() > unsafe.Sizeof(uintptr(0)) { @@ -2053,7 +2080,7 @@ func (v Value) SetMapIndex(key, elem Value) { elemptr = unsafe.Pointer(&elem.value) } - hashmapInterfaceSet(v.pointer(), key.Interface(), elemptr) + hashmapGenericSet(v.pointer(), keyptr, elemptr) } } } @@ -2110,6 +2137,9 @@ func (v Value) FieldByNameFunc(match func(string) bool) Value { //go:linkname hashmapMake runtime.hashmapMake func hashmapMake(keySize, valueSize uintptr, sizeHint uintptr, alg uint8) unsafe.Pointer +//go:linkname hashmapMakeReflect runtime.hashmapMakeReflect +func hashmapMakeReflect(keySize, valueSize, sizeHint uintptr, keyType unsafe.Pointer) unsafe.Pointer + // MakeMapWithSize creates a new map with the specified type and initial space // for approximately n elements. func MakeMapWithSize(typ Type, n int) Value { @@ -2118,7 +2148,6 @@ func MakeMapWithSize(typ Type, n int) Value { const ( hashmapAlgorithmBinary uint8 = iota hashmapAlgorithmString - hashmapAlgorithmInterface ) if typ.Kind() != Map { @@ -2132,18 +2161,19 @@ func MakeMapWithSize(typ Type, n int) Value { key := typ.Key().(*RawType) val := typ.Elem().(*RawType) - var alg uint8 + var m unsafe.Pointer if key.Kind() == String { - alg = hashmapAlgorithmString + m = hashmapMake(key.Size(), val.Size(), uintptr(n), hashmapAlgorithmString) } else if key.isBinary() { - alg = hashmapAlgorithmBinary + m = hashmapMake(key.Size(), val.Size(), uintptr(n), hashmapAlgorithmBinary) } else { - alg = hashmapAlgorithmInterface + // Composite key type (struct with strings, floats, etc.). + // Use runtime-generated hash/equal closures that walk the + // type structure, matching the compiler-generated functions. + m = hashmapMakeReflect(key.Size(), val.Size(), uintptr(n), unsafe.Pointer(key)) } - m := hashmapMake(key.Size(), val.Size(), uintptr(n), alg) - return Value{ typecode: typ.(*RawType), value: m, diff --git a/src/runtime/hashmap.go b/src/runtime/hashmap.go index 894d92a1ba..ba864db305 100644 --- a/src/runtime/hashmap.go +++ b/src/runtime/hashmap.go @@ -13,15 +13,26 @@ import ( // The underlying hashmap structure for Go. type hashmap struct { - buckets unsafe.Pointer // pointer to array of buckets - seed uintptr - count uintptr - keySize uintptr // maybe this can store the key type as well? E.g. keysize == 5 means string? - valueSize uintptr - bucketBits uint8 - keyEqual func(x, y unsafe.Pointer, n uintptr) bool - keyHash func(key unsafe.Pointer, size, seed uintptr) uint32 -} + buckets unsafe.Pointer // pointer to array of buckets + seed uintptr + count uintptr + keySize uintptr + valueSize uintptr + keySlotSize uintptr // == keySize, or sizeof(ptr) if indirect + valueSlotSize uintptr // == valueSize, or sizeof(ptr) if indirect + bucketBits uint8 + flags uint8 + keyEqual func(x, y unsafe.Pointer, n uintptr) bool + keyHash func(key unsafe.Pointer, size, seed uintptr) uint32 +} + +const ( + hashmapMaxKeySize = 128 + hashmapMaxValueSize = 128 + + hashmapFlagIndirectKey = 1 << 0 + hashmapFlagIndirectValue = 1 << 1 +) // A hashmap bucket. A bucket is a container of 8 key/value pairs: first the // following two entries, then the 8 keys, then the 8 values. This somewhat odd @@ -34,6 +45,48 @@ type hashmapBucket struct { // allocated but as they're of variable size they can't be shown here. } +// hashmapBucketHeaderSize is the offset in bytes from the start of a bucket to +// the first key, aligned to 8 bytes. This ensures that keys requiring 8-byte +// alignment (float64, complex128, uint64 on strict-alignment architectures +// like MIPS) are properly aligned in the bucket. +const hashmapBucketHeaderSize = (unsafe.Sizeof(hashmapBucket{}) + 7) &^ 7 + +// hashmapKeySlotSize returns the size of a key slot in the bucket. For indirect +// keys, this is the pointer size; otherwise the actual key size. +// +//go:inline +func hashmapKeySlotSize(m *hashmap) uintptr { + return m.keySlotSize +} + +// hashmapValueSlotSize returns the size of a value slot in the bucket. +// +//go:inline +func hashmapValueSlotSize(m *hashmap) uintptr { + return m.valueSlotSize +} + +// hashmapSlotKeyData returns a pointer to the actual key data for a given slot. +// For indirect keys, the slot contains a pointer that must be dereferenced. +// +//go:inline +func hashmapSlotKeyData(m *hashmap, slotKey unsafe.Pointer) unsafe.Pointer { + if m.flags&hashmapFlagIndirectKey != 0 { + return *(*unsafe.Pointer)(slotKey) + } + return slotKey +} + +// hashmapSlotValueData returns a pointer to the actual value data for a given slot. +// +//go:inline +func hashmapSlotValueData(m *hashmap, slotValue unsafe.Pointer) unsafe.Pointer { + if m.flags&hashmapFlagIndirectValue != 0 { + return *(*unsafe.Pointer)(slotValue) + } + return slotValue +} + type hashmapIterator struct { buckets unsafe.Pointer // pointer to array of hashapBuckets numBuckets uintptr // length of buckets array @@ -66,20 +119,35 @@ func hashmapMake(keySize, valueSize uintptr, sizeHint uintptr, alg uint8) *hashm bucketBits++ } - bucketBufSize := unsafe.Sizeof(hashmapBucket{}) + keySize*8 + valueSize*8 + var flags uint8 + keySlotSize := keySize + if keySize > hashmapMaxKeySize { + flags |= hashmapFlagIndirectKey + keySlotSize = unsafe.Sizeof(unsafe.Pointer(nil)) + } + valueSlotSize := valueSize + if valueSize > hashmapMaxValueSize { + flags |= hashmapFlagIndirectValue + valueSlotSize = unsafe.Sizeof(unsafe.Pointer(nil)) + } + + bucketBufSize := hashmapBucketHeaderSize + keySlotSize*8 + valueSlotSize*8 buckets := alloc(bucketBufSize*(1< hashmapMaxKeySize { + flags |= hashmapFlagIndirectKey + keySlotSize = unsafe.Sizeof(unsafe.Pointer(nil)) + } + valueSlotSize := valueSize + if valueSize > hashmapMaxValueSize { + flags |= hashmapFlagIndirectValue + valueSlotSize = unsafe.Sizeof(unsafe.Pointer(nil)) + } + + bucketBufSize := hashmapBucketHeaderSize + keySlotSize*8 + valueSlotSize*8 + buckets := alloc(bucketBufSize*(1< 1 { + pm := make(map[paddedKey]int) + var pk1, pk2 paddedKey + pk1.A = 1; pk1.B = 42 + pk2.A = 1; pk2.B = 42 + // Poison pk2's padding byte (between A and B). + *(*byte)(unsafe.Add(unsafe.Pointer(&pk2), 1)) = 0xFF + pm[pk1] = 100 + println("padded key lookup:", pm[pk2]) // 100 + println("padded key equal:", pk1 == pk2) // true + } else { + // No padding on this platform; print expected output. + println("padded key lookup:", 100) + println("padded key equal:", true) + } + + // Struct keys with blank fields: blank fields are ignored in equality. + type blankKey struct { + _ int + X string + } + bm := make(map[blankKey]int) + var bk1, bk2 blankKey + bk1.X = "hello" + bk2.X = "hello" + *(*int)(unsafe.Pointer(&bk2)) = 999 + bm[bk1] = 200 + println("blank key lookup:", bm[bk2]) // 200 + println("blank key equal:", bk1 == bk2) // true +} diff --git a/testdata/map.txt b/testdata/map.txt index d5e553b1a7..7a2f587455 100644 --- a/testdata/map.txt +++ b/testdata/map.txt @@ -80,4 +80,14 @@ tested growing of a map 2 2 done +ptr map literal: 1 2 +unsafe ptr literal: 10 +ptr map len: 3 +ptr map a: 10 +ptr map deleted: false +unsafe ptr map: 100 +padded key lookup: 100 +padded key equal: true +blank key lookup: 200 +blank key equal: true no interface lookup failures diff --git a/testdata/map_bigkey.go b/testdata/map_bigkey.go new file mode 100644 index 0000000000..2d2dbec4cb --- /dev/null +++ b/testdata/map_bigkey.go @@ -0,0 +1,63 @@ +package main + +// Test maps with keys and values larger than 128 bytes, which triggers +// indirect storage in the bucket (pointers instead of inline data). +// +// This is a separate file from map.go because the compiler generates many +// large stack temporaries for map operations on [256]byte keys, which +// overflows the goroutine stack on AVR (384 bytes). AVR skips this test. + +type BigKey [256]byte +type BigValue [256]byte + +func main() { + // Large key, small value. + m1 := make(map[BigKey]int) + var k1 BigKey + k1[0] = 1 + k1[255] = 42 + m1[k1] = 100 + + var k1same BigKey + k1same[0] = 1 + k1same[255] = 42 + + var k1diff BigKey + k1diff[0] = 2 + + println("bigkey get:", m1[k1]) + println("bigkey get same:", m1[k1same]) + println("bigkey get diff:", m1[k1diff]) + + // Overwrite. + m1[k1] = 200 + println("bigkey overwrite:", m1[k1]) + + // Small key, large value. + m2 := make(map[int]BigValue) + var v BigValue + v[0] = 7 + v[255] = 99 + m2[1] = v + got := m2[1] + println("bigval get:", got[0], got[255]) + + // Both large. + m3 := make(map[BigKey]BigValue) + m3[k1] = v + got3 := m3[k1] + println("bigboth get:", got3[0], got3[255]) + + // Delete. + delete(m3, k1) + got3 = m3[k1] + println("bigboth deleted:", got3[0]) + + // Iteration. + m1[k1diff] = 300 + count := 0 + for range m1 { + count++ + } + println("bigkey len:", len(m1), "iterated:", count) +} diff --git a/testdata/map_bigkey.txt b/testdata/map_bigkey.txt new file mode 100644 index 0000000000..c16d116e11 --- /dev/null +++ b/testdata/map_bigkey.txt @@ -0,0 +1,8 @@ +bigkey get: 100 +bigkey get same: 100 +bigkey get diff: 0 +bigkey overwrite: 200 +bigval get: 7 99 +bigboth get: 7 99 +bigboth deleted: 0 +bigkey len: 2 iterated: 2 diff --git a/testdata/reflect.go b/testdata/reflect.go index 873d60f787..476ccd1148 100644 --- a/testdata/reflect.go +++ b/testdata/reflect.go @@ -396,6 +396,24 @@ func main() { } } } + + // Test for issue #3794: reflect MapIter.Key() should return a value with + // interface kind for map[interface{}] keys, not the underlying concrete kind. + { + m := make(map[interface{}]int) + m[1] = 2 + m["hello"] = 3 + rv := reflect.ValueOf(m) + iter := rv.MapRange() + for iter.Next() { + k := iter.Key() + if k.Kind() != reflect.Interface { + println("FAIL #3794: expected interface kind, got", k.Kind().String()) + break + } + } + println("reflect map interface key ok") + } } func emptyFunc() { @@ -755,6 +773,9 @@ func testImplements() { // Make FooNode and BarNode implement Node with pointer receivers // (can't add methods to local types in function, use a different approach) testValueSetInterface() + testMakeMapCompositeKey() + testMakeMapInterfaceKey() + testMakeMapPaddedKey() } type IfaceNode interface { @@ -807,3 +828,122 @@ func randuint32() uint32 { xorshift32State = xorshift32(xorshift32State) return xorshift32State } + +type compositeKey struct { + S string + N int32 +} + +// testMakeMapCompositeKey tests that reflect.MakeMap works correctly with +// composite key types (structs containing strings). This exercises the +// hash/equal dispatch path for maps created through reflection rather +// than by the compiler. +func testMakeMapCompositeKey() { + println("\nreflect.MakeMap composite key:") + mapType := reflect.TypeOf(map[compositeKey]int{}) + m := reflect.MakeMap(mapType) + + // Insert two keys that share the same string but differ in the int field. + key1 := reflect.ValueOf(compositeKey{S: "hello", N: 1}) + key2 := reflect.ValueOf(compositeKey{S: "hello", N: 2}) + m.SetMapIndex(key1, reflect.ValueOf(100)) + m.SetMapIndex(key2, reflect.ValueOf(200)) + + println("len:", m.Len()) + + v1 := m.MapIndex(key1) + if v1.IsValid() { + println("key1:", v1.Int()) + } else { + println("key1: not found") + } + v2 := m.MapIndex(key2) + if v2.IsValid() { + println("key2:", v2.Int()) + } else { + println("key2: not found") + } + + // Delete key1, verify key2 remains. + m.SetMapIndex(key1, reflect.Value{}) + println("after delete, len:", m.Len()) + v2 = m.MapIndex(key2) + if v2.IsValid() { + println("key2 after delete:", v2.Int()) + } else { + println("key2 after delete: not found") + } +} + +// testMakeMapInterfaceKey tests that reflect.MakeMap works correctly with +// interface{} key types, including cross-path usage (reflect insert, +// compiled lookup and vice versa). +func testMakeMapInterfaceKey() { + println("\nreflect.MakeMap interface key:") + mapType := reflect.TypeOf(map[interface{}]int{}) + rv := reflect.MakeMap(mapType) + + rv.SetMapIndex(reflect.ValueOf(42), reflect.ValueOf(100)) + rv.SetMapIndex(reflect.ValueOf("hello"), reflect.ValueOf(200)) + println("len:", rv.Len()) + + v1 := rv.MapIndex(reflect.ValueOf(42)) + if v1.IsValid() { + println("42:", v1.Int()) + } else { + println("42: not found") + } + v2 := rv.MapIndex(reflect.ValueOf("hello")) + if v2.IsValid() { + println("hello:", v2.Int()) + } else { + println("hello: not found") + } + + // Cross-path: use from compiled code. + m := rv.Interface().(map[interface{}]int) + println("compiled 42:", m[42]) + println("compiled hello:", m["hello"]) + + // Addressable small value as key. + x := 99 + addrVal := reflect.ValueOf(&x).Elem() + rv.SetMapIndex(addrVal, reflect.ValueOf(300)) + v3 := rv.MapIndex(reflect.ValueOf(99)) + if v3.IsValid() { + println("addressable 99:", v3.Int()) + } else { + println("addressable 99: not found") + } +} + +type paddedKey struct { + A int8 + B int32 +} + +// testMakeMapPaddedKey tests that struct keys with padding work correctly +// through reflect, using addressable values with poisoned padding bytes. +func testMakeMapPaddedKey() { + println("\nreflect.MakeMap padded key:") + var pk1, pk2 paddedKey + pk1.A = 1 + pk1.B = 42 + pk2.A = 1 + pk2.B = 42 + + if unsafe.Offsetof(paddedKey{}.B) > 1 { + // Poison pk2's padding byte (between A and B). + *(*byte)(unsafe.Add(unsafe.Pointer(&pk2), 1)) = 0xFF + } + + // Use addressable values so padding survives into reflect. + rm := reflect.MakeMap(reflect.TypeOf(map[paddedKey]int{})) + rm.SetMapIndex(reflect.ValueOf(&pk1).Elem(), reflect.ValueOf(100)) + v := rm.MapIndex(reflect.ValueOf(&pk2).Elem()) + if v.IsValid() { + println("padded lookup:", v.Int()) + } else { + println("padded lookup: not found") + } +} diff --git a/testdata/reflect.txt b/testdata/reflect.txt index 3024568c3d..dbd992e870 100644 --- a/testdata/reflect.txt +++ b/testdata/reflect.txt @@ -503,6 +503,24 @@ value set interface: Set[0] to BarNode: 10 Set[1] still FooNode: 2 +reflect.MakeMap composite key: +len: 2 +key1: 100 +key2: 200 +after delete, len: 1 +key2 after delete: 200 + +reflect.MakeMap interface key: +len: 2 +42: 100 +hello: 200 +compiled 42: 100 +compiled hello: 200 +addressable 99: 300 + +reflect.MakeMap padded key: +padded lookup: 100 + alignment / offset: struct{[0]func(); byte}: true @@ -512,3 +530,4 @@ blue gopher v.Interface() method kind: interface int 5 +reflect map interface key ok diff --git a/tests/mapbench/mapbench_test.go b/tests/mapbench/mapbench_test.go new file mode 100644 index 0000000000..57c287876b --- /dev/null +++ b/tests/mapbench/mapbench_test.go @@ -0,0 +1,98 @@ +package mapbench + +import "testing" + +type compositeKey struct { + S string + N int32 +} + +var intSink int + +func BenchmarkMapStringShortGet(b *testing.B) { + m := make(map[string]int, 100) + for i := 0; i < 100; i++ { + m[string(rune('A'+i%26))+string(rune('a'+i/26))] = i + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + intSink += m["Qa"] + } +} + +func BenchmarkMapStringLongGet(b *testing.B) { + m := make(map[string]int, 100) + for i := 0; i < 100; i++ { + s := "this-is-a-longer-key-for-testing-" + for j := 0; j < 3; j++ { + s += string(rune('A' + (i+j)%26)) + } + m[s] = i + } + key := "this-is-a-longer-key-for-testing-ABC" + b.ResetTimer() + for i := 0; i < b.N; i++ { + intSink += m[key] + } +} + +func BenchmarkMapCompositeGet(b *testing.B) { + m := make(map[compositeKey]int, 100) + for i := 0; i < 100; i++ { + m[compositeKey{S: string(rune('A'+i%26)) + string(rune('a'+i/26)), N: int32(i)}] = i + } + key := compositeKey{S: "Qa", N: 42} + b.ResetTimer() + for i := 0; i < b.N; i++ { + intSink += m[key] + } +} + +func BenchmarkMapIntGet(b *testing.B) { + m := make(map[int]int, 100) + for i := 0; i < 100; i++ { + m[i*7] = i + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + intSink += m[42] + } +} + +func BenchmarkMapCompositeSet(b *testing.B) { + m := make(map[compositeKey]int, b.N) + b.ResetTimer() + for i := 0; i < b.N; i++ { + m[compositeKey{S: "key", N: int32(i)}] = i + } +} + +type bigKey [256]byte + +func BenchmarkMapBigKeyGet(b *testing.B) { + m := make(map[bigKey]int, 100) + for i := 0; i < 100; i++ { + var k bigKey + k[0] = byte(i) + m[k] = i + } + var k bigKey + k[0] = 42 + b.ResetTimer() + for i := 0; i < b.N; i++ { + intSink += m[k] + } +} + +func BenchmarkMapBigKeySet(b *testing.B) { + m := make(map[bigKey]int, b.N) + b.ResetTimer() + for i := 0; i < b.N; i++ { + var k bigKey + k[0] = byte(i) + k[1] = byte(i >> 8) + k[2] = byte(i >> 16) + k[3] = byte(i >> 24) + m[k] = i + } +} diff --git a/transform/maps.go b/transform/maps.go index 359d9cc575..a078137e4c 100644 --- a/transform/maps.go +++ b/transform/maps.go @@ -10,38 +10,44 @@ import ( // maps. This has not yet been implemented, however. func OptimizeMaps(mod llvm.Module) { hashmapMake := mod.NamedFunction("runtime.hashmapMake") - if hashmapMake.IsNil() { - // nothing to optimize - return - } + hashmapMakeGeneric := mod.NamedFunction("runtime.hashmapMakeGeneric") hashmapBinarySet := mod.NamedFunction("runtime.hashmapBinarySet") hashmapStringSet := mod.NamedFunction("runtime.hashmapStringSet") + hashmapGenericSet := mod.NamedFunction("runtime.hashmapGenericSet") - for _, makeInst := range getUses(hashmapMake) { - updateInsts := []llvm.Value{} - unknownUses := false // are there any uses other than setting a value? + optimizeMapMake := func(makeFunc llvm.Value) { + if makeFunc.IsNil() { + return + } + for _, makeInst := range getUses(makeFunc) { + updateInsts := []llvm.Value{} + unknownUses := false - for _, use := range getUses(makeInst) { - if use := use.IsACallInst(); !use.IsNil() { - switch use.CalledValue() { - case hashmapBinarySet, hashmapStringSet: - updateInsts = append(updateInsts, use) - default: + for _, use := range getUses(makeInst) { + if use := use.IsACallInst(); !use.IsNil() { + switch use.CalledValue() { + case hashmapBinarySet, hashmapStringSet, hashmapGenericSet: + updateInsts = append(updateInsts, use) + default: + unknownUses = true + } + } else { unknownUses = true } - } else { - unknownUses = true } - } - if !unknownUses { - // This map can be entirely removed, as it is only created but never - // used. - for _, inst := range updateInsts { - inst.EraseFromParentAsInstruction() + if !unknownUses { + // This map can be entirely removed, as it is only created + // but never used. + for _, inst := range updateInsts { + inst.EraseFromParentAsInstruction() + } + makeInst.EraseFromParentAsInstruction() } - makeInst.EraseFromParentAsInstruction() } } + + optimizeMapMake(hashmapMake) + optimizeMapMake(hashmapMakeGeneric) }