@@ -78,6 +78,7 @@ module Data.HashMap.Internal
7878 , intersection
7979 , intersectionWith
8080 , intersectionWithKey
81+ , intersectionWithKey #
8182
8283 -- * Folds
8384 , foldr'
@@ -150,9 +151,9 @@ import Data.Data (Constr, Data (..), DataType)
150151import Data.Functor.Classes (Eq1 (.. ), Eq2 (.. ), Ord1 (.. ), Ord2 (.. ),
151152 Read1 (.. ), Show1 (.. ), Show2 (.. ))
152153import Data.Functor.Identity (Identity (.. ))
153- import Data.HashMap.Internal.List (isPermutationBy , unorderedCompare )
154154import Data.Hashable (Hashable )
155155import Data.Hashable.Lifted (Hashable1 , Hashable2 )
156+ import Data.HashMap.Internal.List (isPermutationBy , unorderedCompare )
156157import Data.Semigroup (Semigroup (.. ), stimesIdempotentMonoid )
157158import GHC.Exts (Int (.. ), Int #, TYPE , (==#) )
158159import GHC.Stack (HasCallStack )
@@ -163,9 +164,9 @@ import Text.Read hiding (step)
163164import qualified Data.Data as Data
164165import qualified Data.Foldable as Foldable
165166import qualified Data.Functor.Classes as FC
166- import qualified Data.HashMap.Internal.Array as A
167167import qualified Data.Hashable as H
168168import qualified Data.Hashable.Lifted as H
169+ import qualified Data.HashMap.Internal.Array as A
169170import qualified Data.List as List
170171import qualified GHC.Exts as Exts
171172import qualified Language.Haskell.TH.Syntax as TH
@@ -1627,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do
16271628 A. write mary i =<< A. indexM ary2 i2
16281629 go (i+ 1 ) i1 (i2+ 1 ) b'
16291630 where
1630- m = 1 `unsafeShiftL` ( countTrailingZeros b)
1631+ m = 1 `unsafeShiftL` countTrailingZeros b
16311632 testBit x = x .&. m /= 0
16321633 b' = b .&. complement m
16331634 go 0 0 0 bCombined
@@ -1759,37 +1760,161 @@ differenceWith f a b = foldlWithKey' go empty a
-- | \(O(n \log m)\) Intersection of two maps. For every key that occurs in
-- both maps, the value stored in the first map is kept and the value from
-- the second map is discarded.
intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v
intersection = Exts.inline intersectionWith (\v _ -> v)
{-# INLINABLE intersection #-}
17681765
-- | \(O(n \log m)\) Intersection of two maps. For every key that occurs in
-- both maps, the supplied function combines the value from the first map
-- with the value from the second map.
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWith f = Exts.inline intersectionWithKey (\_ -> f)
{-# INLINABLE intersectionWith #-}
17801772
-- | \(O(n \log m)\) Intersection of two maps. For every key that occurs in
-- both maps, the supplied function is given the key together with the value
-- from the first map and the value from the second map, and its result is
-- stored in the resulting map.
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey f = intersectionWithKey# combine
  where
    -- Wrap the result in an unboxed 1-tuple, as the worker expects.
    combine k x y = (# f k x y #)
{-# INLINABLE intersectionWithKey #-}
17921779
-- | Worker behind the intersection functions. The combining function returns
-- its result in an unboxed 1-tuple. Both trees are walked in lockstep,
-- starting at shift 0 and moving down by 'bitsPerSubkey' per level.
intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
intersectionWithKey# f = go 0
  where
    -- Continue the walk one level further down.
    descend s = go (s + bitsPerSubkey)

    -- An empty map on either side yields an empty result.
    go !_ _ Empty = Empty
    go _ Empty _ = Empty
    -- Single leaf on the left: look its key up in the right map.
    go s (Leaf h1 (L k1 v1)) t2 = lookupCont (\_ -> Empty) found h1 k1 s t2
      where
        found v2 _ = case f k1 v1 v2 of
          (# v3 #) -> Leaf h1 (L k1 v3)
    -- Single leaf on the right: look its key up in the left map.
    -- Note that the resulting leaf carries the key stored in the right map.
    go s t1 (Leaf h2 (L k2 v2)) = lookupCont (\_ -> Empty) found h2 k2 s t1
      where
        found v1 _ = case f k2 v1 v2 of
          (# v3 #) -> Leaf h2 (L k2 v3)
    -- Two collision nodes: intersect their leaf arrays directly.
    go _ (Collision h1 ls1) (Collision h2 ls2) =
      intersectionCollisions f h1 h2 ls1 ls2
    -- Two branch nodes: intersect child by child; a Full node is treated as
    -- a branch whose bitmap is 'fullNodeMask'.
    go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
      intersectionArrayBy (descend s) b1 b2 ary1 ary2
    go s (BitmapIndexed b1 ary1) (Full ary2) =
      intersectionArrayBy (descend s) b1 fullNodeMask ary1 ary2
    go s (Full ary1) (BitmapIndexed b2 ary2) =
      intersectionArrayBy (descend s) fullNodeMask b2 ary1 ary2
    go s (Full ary1) (Full ary2) =
      intersectionArrayBy (descend s) fullNodeMask fullNodeMask ary1 ary2
    -- Branch against collision: only the child selected by the collision's
    -- hash at this level can share keys with it.
    go s (BitmapIndexed b1 ary1) t2@(Collision h2 _)
      | b1 .&. bit2 /= 0 = descend s (A.index ary1 (sparseIndex b1 bit2)) t2
      | otherwise = Empty
      where
        bit2 = mask h2 s
    go s t1@(Collision h1 _) (BitmapIndexed b2 ary2)
      | b2 .&. bit1 /= 0 = descend s t1 (A.index ary2 (sparseIndex b2 bit1))
      | otherwise = Empty
      where
        bit1 = mask h1 s
    go s (Full ary1) t2@(Collision h2 _) =
      descend s (A.index ary1 (index h2 s)) t2
    go s t1@(Collision h1 _) (Full ary2) =
      descend s t1 (A.index ary2 (index h1 s))
{-# INLINE intersectionWithKey# #-}
1828+
-- | Intersect the children of two branch nodes that sit at the same level.
--
-- Arguments: the function used to intersect a pair of children, the bitmap
-- of the first node, the bitmap of the second node, and the two child
-- arrays (each indexed in the order of the set bits of its bitmap).
--
-- Only slots present in both bitmaps are combined. Children whose
-- intersection is 'Empty' are dropped and their bit is cleared from the
-- resulting bitmap.
intersectionArrayBy ::
  ( HashMap k v1 ->
    HashMap k v2 ->
    HashMap k v3
  ) ->
  Bitmap ->
  Bitmap ->
  A.Array (HashMap k v1) ->
  A.Array (HashMap k v2) ->
  HashMap k v3
intersectionArrayBy f !b1 !b2 !ary1 !ary2
  | bIntersect == 0 = Empty
  | otherwise = runST $ do
    -- Upper bound on the number of surviving children; shrunk afterwards.
    mary <- A.new_ $ popCount bIntersect
    -- Iterate over the set bits of b1 .|. b2, lowest first.
    --   i    : next free slot in the output array
    --   i1   : current position in ary1
    --   i2   : current position in ary2
    --   b    : bits still to be visited
    --   bAcc : bitmap of the result built so far
    let go !i !i1 !i2 !b !bAcc
          | b == 0 = pure (i, bAcc)
          | testBit bIntersect = do
            x1 <- A.indexM ary1 i1
            x2 <- A.indexM ary2 i2
            -- Evaluate the child intersection exactly once and reuse it.
            case f x1 x2 of
              Empty -> go i (i1 + 1) (i2 + 1) b' (bAcc .&. complement m)
              t -> do
                A.write mary i t
                go (i + 1) (i1 + 1) (i2 + 1) b' bAcc
          | testBit b1 = go i (i1 + 1) i2 b' bAcc
          | otherwise = go i i1 (i2 + 1) b' bAcc
          where
            -- Mask selecting the lowest bit of b that is still set.
            m = 1 `unsafeShiftL` countTrailingZeros b
            testBit x = x .&. m /= 0
            b' = b .&. complement m
    (len, bFinal) <- go 0 0 0 bCombined bIntersect
    case len of
      0 -> pure Empty
      1 -> do
        t <- A.read mary 0
        case t of
          -- Leaves and collisions carry their full hash, so a lone one can
          -- replace this node.
          Leaf{} -> pure t
          Collision{} -> pure t
          -- A lone subtree was built for the next level down; hoisting it
          -- would make its keys unreachable, so keep it wrapped in a
          -- one-entry branch node.
          _ -> BitmapIndexed bFinal <$> (A.unsafeFreeze =<< A.shrink mary 1)
      _ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len)
  where
    bCombined = b1 .|. b2
    bIntersect = b1 .&. b2
{-# INLINE intersectionArrayBy #-}
1869+
-- | Intersect the leaf arrays of two collision nodes.
--
-- If the two hashes differ the result is 'Empty'. Otherwise every leaf of
-- the first array is searched for in a mutable version of the second array
-- (obtained via 'A.thaw'; presumably a copy, so @ary2@ itself is untouched).
-- Matches are combined with @f@ and collected, keeping the key of the first
-- array. The result is 'Empty', a single 'Leaf', or a 'Collision',
-- depending on how many keys matched.
intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3
intersectionCollisions f h1 h2 ary1 ary2
  | h1 /= h2 = Empty
  | otherwise = runST $ do
    candidates <- A.thaw ary2 0 (A.length ary2)
    out <- A.new_ (min (A.length ary1) (A.length ary2))
    -- i walks over ary1; n counts the matches found so far. n is both the
    -- next output slot and the start of the unmatched part of candidates.
    let loop i n
          | n >= A.lengthM candidates = pure n
          | i >= A.length ary1 = pure n
          | otherwise = do
            L k1 v1 <- A.indexM ary1 i
            hit <- searchSwap k1 n candidates
            case hit of
              Nothing -> loop (i + 1) n
              Just (L _ v2) -> do
                let !(# v3 #) = f k1 v1 v2
                A.write out n (L k1 v3)
                loop (i + 1) (n + 1)
    matched <- loop 0 0
    case matched of
      0 -> pure Empty
      1 -> Leaf h1 <$> A.read out 0
      _ -> do
        trimmed <- A.shrink out matched
        Collision h1 <$> A.unsafeFreeze trimmed
{-# INLINE intersectionCollisions #-}
1893+
-- | Search a mutable leaf array for a key, scanning from the given start
-- index to the end. On a hit, the slot that held the match is overwritten
-- with the element found at the start index and the matching leaf is
-- returned; on a miss the array is left unchanged and 'Nothing' is returned.
--
-- Example: given
-- @
-- 1 2 3 4
-- @
-- a search for @3@ starting at index 0 leaves the array as
-- @
-- undefined 2 1 4
-- @
-- The start slot is never actually overwritten with undefined; the caller
-- only has to begin its next search one position after the current start.
searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v))
searchSwap toFind start mary = scan start
  where
    scan i
      | i >= A.lengthM mary = pure Nothing
      | otherwise = do
        leaf@(L candidate _) <- A.read mary i
        if toFind == candidate
          then do
            displaced <- A.read mary start
            A.write mary i displaced
            pure (Just leaf)
          else scan (i + 1)
{-# INLINE searchSwap #-}
1916+
1917+
17931918------------------------------------------------------------------------
17941919-- * Folds
17951920
0 commit comments