Skip to content

Commit c40a0d5

Browse files
authored
Make array creation friendlier to worker-wrapper (#104)
* Let GHC can erase `Array` constructors in many cases. * Add a `runArray` function like the ones in `array` and `vector`. Closes #102
1 parent 9df14ca commit c40a0d5

File tree

2 files changed

+141
-39
lines changed

2 files changed

+141
-39
lines changed

Data/Primitive/Array.hs

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module Data.Primitive.Array (
1717
Array(..), MutableArray(..),
1818

1919
newArray, readArray, writeArray, indexArray, indexArrayM,
20-
freezeArray, thawArray,
20+
freezeArray, thawArray, runArray,
2121
unsafeFreezeArray, unsafeThawArray, sameMutableArray,
2222
copyArray, copyMutableArray,
2323
cloneArray, cloneMutableArray,
@@ -54,12 +54,18 @@ import Data.Traversable (Traversable(..))
5454
import Data.Monoid
5555
#endif
5656
#if MIN_VERSION_base(4,9,0)
57+
import qualified GHC.ST as GHCST
5758
import qualified Data.Foldable as F
5859
import Data.Semigroup
5960
#endif
6061
#if MIN_VERSION_base(4,8,0)
6162
import Data.Functor.Identity
6263
#endif
64+
#if MIN_VERSION_base(4,10,0)
65+
import GHC.Exts (runRW#)
66+
#elif MIN_VERSION_base(4,9,0)
67+
import GHC.Base (runRW#)
68+
#endif
6369

6470
import Text.ParserCombinators.ReadP
6571

@@ -278,16 +284,63 @@ emptyArray =
278284
runST $ newArray 0 (die "emptyArray" "impossible") >>= unsafeFreezeArray
279285
{-# NOINLINE emptyArray #-}
280286

287+
#if !MIN_VERSION_base(4,9,0)
281288
createArray
282289
:: Int
283290
-> a
284291
-> (forall s. MutableArray s a -> ST s ())
285292
-> Array a
286293
createArray 0 _ _ = emptyArray
287-
createArray n x f = runST $ do
288-
ma <- newArray n x
289-
f ma
290-
unsafeFreezeArray ma
294+
createArray n x f = runArray $ do
295+
mary <- newArray n x
296+
f mary
297+
pure mary
298+
299+
runArray
300+
:: (forall s. ST s (MutableArray s a))
301+
-> Array a
302+
runArray m = runST $ m >>= unsafeFreezeArray
303+
304+
#else /* Below, runRW# is available. */
305+
306+
-- This low-level business is designed to work with GHC's worker-wrapper
307+
-- transformation. A lot of the time, we don't actually need an Array
308+
-- constructor. By putting it on the outside, and being careful about
309+
-- how we special-case the empty array, we can make GHC smarter about this.
310+
-- The only downside is that separately created 0-length arrays won't share
311+
-- their Array constructors, although they'll share their underlying
312+
-- Array#s.
313+
createArray
314+
:: Int
315+
-> a
316+
-> (forall s. MutableArray s a -> ST s ())
317+
-> Array a
318+
createArray 0 _ _ = Array (emptyArray# (# #))
319+
createArray n x f = runArray $ do
320+
mary <- newArray n x
321+
f mary
322+
pure mary
323+
324+
runArray
325+
:: (forall s. ST s (MutableArray s a))
326+
-> Array a
327+
runArray m = Array (runArray# m)
328+
329+
runArray#
330+
:: (forall s. ST s (MutableArray s a))
331+
-> Array# a
332+
runArray# m = case runRW# $ \s ->
333+
case unST m s of { (# s', MutableArray mary# #) ->
334+
unsafeFreezeArray# mary# s'} of (# _, ary# #) -> ary#
335+
336+
unST :: ST s a -> State# s -> (# State# s, a #)
337+
unST (GHCST.ST f) = f
338+
339+
emptyArray# :: (# #) -> Array# a
340+
emptyArray# _ = case emptyArray of Array ar -> ar
341+
{-# NOINLINE emptyArray# #-}
342+
#endif
343+
291344

292345
die :: String -> String -> a
293346
die fun problem = error $ "Data.Primitive.Array." ++ fun ++ ": " ++ problem
@@ -507,18 +560,17 @@ unsafeTraverseArray f = \ !ary ->
507560
{-# INLINE unsafeTraverseArray #-}
508561

509562
arrayFromListN :: Int -> [a] -> Array a
510-
arrayFromListN n l = runST $ do
511-
sma <- newArray n (die "fromListN" "uninitialized element")
512-
let go !ix [] = if ix == n
513-
then return ()
514-
else die "fromListN" "list length less than specified size"
515-
go !ix (x : xs) = if ix < n
516-
then do
517-
writeArray sma ix x
518-
go (ix+1) xs
519-
else die "fromListN" "list length greater than specified size"
520-
go 0 l
521-
unsafeFreezeArray sma
563+
arrayFromListN n l =
564+
createArray n (die "fromListN" "uninitialized element") $ \sma ->
565+
let go !ix [] = if ix == n
566+
then return ()
567+
else die "fromListN" "list length less than specified size"
568+
go !ix (x : xs) = if ix < n
569+
then do
570+
writeArray sma ix x
571+
go (ix+1) xs
572+
else die "fromListN" "list length greater than specified size"
573+
in go 0 l
522574

523575
arrayFromList :: [a] -> Array a
524576
arrayFromList l = arrayFromListN (length l) l
@@ -547,13 +599,12 @@ instance Functor Array where
547599
writeArray mb i (f x) >> go (i+1)
548600
in go 0
549601
#if MIN_VERSION_base(4,8,0)
550-
e <$ a = runST $ newArray (sizeofArray a) e >>= unsafeFreezeArray
602+
e <$ a = createArray (sizeofArray a) e (\ !_ -> pure ())
551603
#endif
552604

553605
instance Applicative Array where
554-
pure x = runST $ newArray 1 x >>= unsafeFreezeArray
555-
ab <*> a = runST $ do
556-
mb <- newArray (szab*sza) $ die "<*>" "impossible"
606+
pure x = runArray $ newArray 1 x
607+
ab <*> a = createArray (szab*sza) (die "<*>" "impossible") $ \mb ->
557608
let go1 i = when (i < szab) $
558609
do
559610
f <- indexArrayM ab i
@@ -564,8 +615,7 @@ instance Applicative Array where
564615
x <- indexArrayM a j
565616
writeArray mb (off + j) (f x)
566617
go2 off f (j + 1)
567-
go1 0
568-
unsafeFreezeArray mb
618+
in go1 0
569619
where szab = sizeofArray ab ; sza = sizeofArray a
570620
a *> b = createArray (sza*szb) (die "*>" "impossible") $ \mb ->
571621
let go i | i < sza = copyArray mb (i * szb) b 0 szb

Data/Primitive/SmallArray.hs

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ module Data.Primitive.SmallArray
5151
, freezeSmallArray
5252
, unsafeFreezeSmallArray
5353
, thawSmallArray
54+
, runSmallArray
5455
, unsafeThawSmallArray
5556
, sizeofSmallArray
5657
, sizeofSmallMutableArray
@@ -78,13 +79,19 @@ import Control.Monad.Zip
7879
import Data.Data
7980
import Data.Foldable as Foldable
8081
import Data.Functor.Identity
81-
#if !(MIN_VERSION_base(4,11,0))
82+
#if !(MIN_VERSION_base(4,10,0))
8283
import Data.Monoid
8384
#endif
8485
#if MIN_VERSION_base(4,9,0)
86+
import qualified GHC.ST as GHCST
8587
import qualified Data.Semigroup as Sem
8688
#endif
8789
import Text.ParserCombinators.ReadP
90+
#if MIN_VERSION_base(4,10,0)
91+
import GHC.Exts (runRW#)
92+
#elif MIN_VERSION_base(4,9,0)
93+
import GHC.Base (runRW#)
94+
#endif
8895

8996
#if !(HAVE_SMALL_ARRAY)
9097
import Data.Primitive.Array
@@ -429,7 +436,61 @@ unsafeTraverseSmallArray f (SmallArray ar) = SmallArray `liftM` unsafeTraverseAr
429436
#endif
430437
{-# INLINE unsafeTraverseSmallArray #-}
431438

439+
#ifndef HAVE_SMALL_ARRAY
440+
runSmallArray
441+
:: (forall s. ST s (SmallMutableArray s a))
442+
-> SmallArray a
443+
runSmallArray m = SmallArray $ runArray $
444+
m >>= \(SmallMutableArray mary) -> return mary
445+
446+
#elif !MIN_VERSION_base(4,9,0)
447+
runSmallArray
448+
:: (forall s. ST s (SmallMutableArray s a))
449+
-> SmallArray a
450+
runSmallArray m = runST $ m >>= unsafeFreezeSmallArray
451+
452+
#else
453+
-- This low-level business is designed to work with GHC's worker-wrapper
454+
-- transformation. A lot of the time, we don't actually need an Array
455+
-- constructor. By putting it on the outside, and being careful about
456+
-- how we special-case the empty array, we can make GHC smarter about this.
457+
-- The only downside is that separately created 0-length arrays won't share
458+
-- their Array constructors, although they'll share their underlying
459+
-- Array#s.
460+
runSmallArray
461+
:: (forall s. ST s (SmallMutableArray s a))
462+
-> SmallArray a
463+
runSmallArray m = SmallArray (runSmallArray# m)
464+
465+
runSmallArray#
466+
:: (forall s. ST s (SmallMutableArray s a))
467+
-> SmallArray# a
468+
runSmallArray# m = case runRW# $ \s ->
469+
case unST m s of { (# s', SmallMutableArray mary# #) ->
470+
unsafeFreezeSmallArray# mary# s'} of (# _, ary# #) -> ary#
471+
472+
unST :: ST s a -> State# s -> (# State# s, a #)
473+
unST (GHCST.ST f) = f
474+
475+
#endif
476+
432477
#if HAVE_SMALL_ARRAY
478+
-- See the comment on runSmallArray for why we use emptySmallArray#.
479+
createSmallArray
480+
:: Int
481+
-> a
482+
-> (forall s. SmallMutableArray s a -> ST s ())
483+
-> SmallArray a
484+
createSmallArray 0 _ _ = SmallArray (emptySmallArray# (# #))
485+
createSmallArray n x f = runSmallArray $ do
486+
mary <- newSmallArray n x
487+
f mary
488+
pure mary
489+
490+
emptySmallArray# :: (# #) -> SmallArray# a
491+
emptySmallArray# _ = case emptySmallArray of SmallArray ar -> ar
492+
{-# NOINLINE emptySmallArray# #-}
493+
433494
die :: String -> String -> a
434495
die fun problem = error $ "Data.Primitive.SmallArray." ++ fun ++ ": " ++ problem
435496

@@ -439,12 +500,6 @@ emptySmallArray =
439500
>>= unsafeFreezeSmallArray
440501
{-# NOINLINE emptySmallArray #-}
441502

442-
createSmallArray
443-
:: Int -> a -> (forall s. SmallMutableArray s a -> ST s ()) -> SmallArray a
444-
createSmallArray 0 _ _ = emptySmallArray
445-
createSmallArray i x k =
446-
runST $ newSmallArray i x >>= \sa -> k sa *> unsafeFreezeSmallArray sa
447-
{-# INLINE createSmallArray #-}
448503

449504
infixl 1 ?
450505
(?) :: (a -> b -> c) -> (b -> a -> c)
@@ -666,8 +721,7 @@ instance Applicative SmallArray where
666721
in go 0
667722
where sza = sizeofSmallArray a ; szb = sizeofSmallArray b
668723

669-
ab <*> a = runST $ do
670-
mb <- newSmallArray (szab*sza) $ die "<*>" "impossible"
724+
ab <*> a = createSmallArray (szab*sza) (die "<*>" "impossible") $ \mb ->
671725
let go1 i = when (i < szab) $
672726
do
673727
f <- indexSmallArrayM ab i
@@ -678,8 +732,7 @@ instance Applicative SmallArray where
678732
x <- indexSmallArrayM a j
679733
writeSmallArray mb (off + j) (f x)
680734
go2 off f (j + 1)
681-
go1 0
682-
unsafeFreezeSmallArray mb
735+
in go1 0
683736
where szab = sizeofSmallArray ab ; sza = sizeofSmallArray a
684737

685738
instance Alternative SmallArray where
@@ -868,8 +921,9 @@ instance (Typeable s, Typeable a) => Data (SmallMutableArray s a) where
868921
-- of the list does not match the given length, this throws an exception.
869922
smallArrayFromListN :: Int -> [a] -> SmallArray a
870923
#if HAVE_SMALL_ARRAY
871-
smallArrayFromListN n l = runST $ do
872-
sma <- newSmallArray n (die "smallArrayFromListN" "uninitialized element")
924+
smallArrayFromListN n l =
925+
createSmallArray n
926+
(die "smallArrayFromListN" "uninitialized element") $ \sma ->
873927
let go !ix [] = if ix == n
874928
then return ()
875929
else die "smallArrayFromListN" "list length less than specified size"
@@ -878,13 +932,11 @@ smallArrayFromListN n l = runST $ do
878932
writeSmallArray sma ix x
879933
go (ix+1) xs
880934
else die "smallArrayFromListN" "list length greater than specified size"
881-
go 0 l
882-
unsafeFreezeSmallArray sma
935+
in go 0 l
883936
#else
884937
smallArrayFromListN n l = SmallArray (Array.fromListN n l)
885938
#endif
886939

887940
-- | Create a 'SmallArray' from a list.
888941
smallArrayFromList :: [a] -> SmallArray a
889942
smallArrayFromList l = smallArrayFromListN (length l) l
890-

0 commit comments

Comments
 (0)