Skip to content

Commit dc89b87

Browse files
authored
Merge pull request #88 from treeowl/traverse-better
Traverse better
2 parents 5970879 + 90e2e59 commit dc89b87

File tree

2 files changed

+157
-68
lines changed

2 files changed

+157
-68
lines changed

Data/Primitive/Array.hs

Lines changed: 82 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ module Data.Primitive.Array (
2222
copyArray, copyMutableArray,
2323
cloneArray, cloneMutableArray,
2424
sizeofArray, sizeofMutableArray,
25-
fromListN, fromList
25+
fromListN, fromList,
26+
unsafeTraverseArray
2627
) where
2728

2829
import Control.Monad.Primitive
@@ -56,36 +57,29 @@ import Data.Monoid
5657
import qualified Data.Foldable as F
5758
import Data.Semigroup
5859
#endif
60+
#if MIN_VERSION_base(4,8,0)
61+
import Data.Functor.Identity
62+
#endif
5963

6064
import Text.ParserCombinators.ReadP
6165

6266
-- | Boxed arrays
6367
data Array a = Array
64-
{ array# :: Array# a
65-
#if (__GLASGOW_HASKELL__ < 702)
66-
, sizeofArray :: {-# UNPACK #-} !Int
67-
#endif
68-
}
68+
{ array# :: Array# a }
6969
deriving ( Typeable )
7070

7171
-- | Mutable boxed arrays associated with a primitive state token.
7272
data MutableArray s a = MutableArray
73-
{ marray# :: MutableArray# s a
74-
#if (__GLASGOW_HASKELL__ < 702)
75-
, sizeofMutableArray :: {-# UNPACK #-} !Int
76-
#endif
77-
}
73+
{ marray# :: MutableArray# s a }
7874
deriving ( Typeable )
7975

80-
#if (__GLASGOW_HASKELL__ >= 702)
8176
sizeofArray :: Array a -> Int
8277
sizeofArray a = I# (sizeofArray# (array# a))
8378
{-# INLINE sizeofArray #-}
8479

8580
sizeofMutableArray :: MutableArray s a -> Int
8681
sizeofMutableArray a = I# (sizeofMutableArray# (marray# a))
8782
{-# INLINE sizeofMutableArray #-}
88-
#endif
8983

9084
-- | Create a new mutable array of the specified size and initialise all
9185
-- elements with the given value.
@@ -95,9 +89,6 @@ newArray (I# n#) x = primitive
9589
(\s# -> case newArray# n# x s# of
9690
(# s'#, arr# #) ->
9791
let ma = MutableArray arr#
98-
#if (__GLASGOW_HASKELL__ < 702)
99-
(I# n#)
100-
#endif
10192
in (# s'# , ma #))
10293

10394
-- | Read a value from the array at the given index.
@@ -161,16 +152,9 @@ freezeArray
161152
-> Int -- ^ length
162153
-> m (Array a)
163154
{-# INLINE freezeArray #-}
164-
#if (__GLASGOW_HASKELL__ >= 702)
165155
freezeArray (MutableArray ma#) (I# off#) (I# len#) =
166156
primitive $ \s -> case freezeArray# ma# off# len# s of
167157
(# s', a# #) -> (# s', Array a# #)
168-
#else
169-
freezeArray src off len = do
170-
dst <- newArray len (die "freezeArray" "impossible")
171-
copyMutableArray dst 0 src off len
172-
unsafeFreezeArray dst
173-
#endif
174158

175159
-- | Convert a mutable array to an immutable one without copying. The
176160
-- array should not be modified after the conversion.
@@ -180,9 +164,6 @@ unsafeFreezeArray arr
180164
= primitive (\s# -> case unsafeFreezeArray# (marray# arr) s# of
181165
(# s'#, arr'# #) ->
182166
let a = Array arr'#
183-
#if (__GLASGOW_HASKELL__ < 702)
184-
(sizeofMutableArray arr)
185-
#endif
186167
in (# s'#, a #))
187168

188169
-- | Create a mutable array from a slice of an immutable array.
@@ -196,16 +177,9 @@ thawArray
196177
-> Int -- ^ length
197178
-> m (MutableArray (PrimState m) a)
198179
{-# INLINE thawArray #-}
199-
#if (__GLASGOW_HASKELL__ >= 702)
200180
thawArray (Array a#) (I# off#) (I# len#) =
201181
primitive $ \s -> case thawArray# a# off# len# s of
202182
(# s', ma# #) -> (# s', MutableArray ma# #)
203-
#else
204-
thawArray src off len = do
205-
dst <- newArray len (die "thawArray" "impossible")
206-
copyArray dst 0 src off len
207-
return dst
208-
#endif
209183

210184
-- | Convert an immutable array to an mutable one without copying. The
211185
-- immutable array should not be used after the conversion.
@@ -215,9 +189,6 @@ unsafeThawArray a
215189
= primitive (\s# -> case unsafeThawArray# (array# a) s# of
216190
(# s'#, arr'# #) ->
217191
let ma = MutableArray arr'#
218-
#if (__GLASGOW_HASKELL__ < 702)
219-
(sizeofArray a)
220-
#endif
221192
in (# s'#, ma #))
222193

223194
-- | Check whether the two arrays refer to the same memory block.
@@ -282,15 +253,8 @@ cloneArray :: Array a -- ^ source array
282253
-> Int -- ^ number of elements to copy
283254
-> Array a
284255
{-# INLINE cloneArray #-}
285-
#if __GLASGOW_HASKELL__ >= 702
286256
cloneArray (Array arr#) (I# off#) (I# len#)
287257
= case cloneArray# arr# off# len# of arr'# -> Array arr'#
288-
#else
289-
cloneArray arr off len = runST $ do
290-
marr2 <- newArray len $ die "cloneArray" "impossible"
291-
copyArray marr2 0 arr off len
292-
unsafeFreezeArray marr2
293-
#endif
294258

295259
-- | Return a newly allocated MutableArray. with the specified subrange of
296260
-- the provided MutableArray. The provided MutableArray should contain the
@@ -301,21 +265,9 @@ cloneMutableArray :: PrimMonad m
301265
-> Int -- ^ number of elements to copy
302266
-> m (MutableArray (PrimState m) a)
303267
{-# INLINE cloneMutableArray #-}
304-
#if __GLASGOW_HASKELL__ >= 702
305268
cloneMutableArray (MutableArray arr#) (I# off#) (I# len#) = primitive
306269
(\s# -> case cloneMutableArray# arr# off# len# s# of
307270
(# s'#, arr'# #) -> (# s'#, MutableArray arr'# #))
308-
#else
309-
cloneMutableArray marr off len = do
310-
marr2 <- newArray len $ die "cloneMutableArray" "impossible"
311-
let go !i !j c
312-
| c >= len = return marr2
313-
| otherwise = do
314-
b <- readArray marr i
315-
writeArray marr2 j b
316-
go (i+1) (j+1) (c+1)
317-
go off 0 0
318-
#endif
319271

320272
emptyArray :: Array a
321273
emptyArray =
@@ -444,10 +396,82 @@ instance Foldable Array where
444396
{-# INLINE product #-}
445397
#endif
446398

399+
newtype STA a = STA {_runSTA :: forall s. MutableArray# s a -> ST s (Array a)}
400+
401+
runSTA :: Int -> STA a -> Array a
402+
runSTA !sz = \ (STA m) -> runST $ newArray_ sz >>= \ ar -> m (marray# ar)
403+
{-# INLINE runSTA #-}
404+
405+
newArray_ :: Int -> ST s (MutableArray s a)
406+
newArray_ !n = newArray n badTraverseValue
407+
408+
badTraverseValue :: a
409+
badTraverseValue = die "traverse" "bad indexing"
410+
{-# NOINLINE badTraverseValue #-}
411+
447412
instance Traversable Array where
448-
traverse f a =
449-
fromListN (sizeofArray a)
450-
<$> traverse (f . indexArray a) [0 .. sizeofArray a - 1]
413+
traverse f = traverseArray f
414+
{-# INLINE traverse #-}
415+
416+
traverseArray
417+
:: Applicative f
418+
=> (a -> f b)
419+
-> Array a
420+
-> f (Array b)
421+
traverseArray f = \ !ary ->
422+
let
423+
!len = sizeofArray ary
424+
go !i
425+
| i == len = pure $ STA $ \mary -> unsafeFreezeArray (MutableArray mary)
426+
| (# x #) <- indexArray## ary i
427+
= liftA2 (\b (STA m) -> STA $ \mary ->
428+
writeArray (MutableArray mary) i b >> m mary)
429+
(f x) (go (i + 1))
430+
in if len == 0
431+
then pure emptyArray
432+
else runSTA len <$> go 0
433+
{-# INLINE [1] traverseArray #-}
434+
435+
{-# RULES
436+
"traverse/ST" forall (f :: a -> ST s b). traverseArray f =
437+
unsafeTraverseArray f
438+
"traverse/IO" forall (f :: a -> IO b). traverseArray f =
439+
unsafeTraverseArray f
440+
#-}
441+
#if MIN_VERSION_base(4,8,0)
442+
{-# RULES
443+
"traverse/Id" forall (f :: a -> Identity b). traverseArray f =
444+
(coerce :: (Array a -> Array (Identity b))
445+
-> Array a -> Identity (Array b)) (fmap f)
446+
#-}
447+
#endif
448+
449+
-- | This is the fastest, most straightforward way to traverse
450+
-- an array, but it only works correctly with a sufficiently
451+
-- "affine" 'PrimMonad' instance. In particular, it must only produce
452+
-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed
453+
-- monads, for example, will not work right at all.
454+
unsafeTraverseArray
455+
:: PrimMonad m
456+
=> (a -> m b)
457+
-> Array a
458+
-> m (Array b)
459+
unsafeTraverseArray f = \ !ary ->
460+
let
461+
!sz = sizeofArray ary
462+
go !i !mary
463+
| i == sz
464+
= unsafeFreezeArray mary
465+
| otherwise
466+
= do
467+
a <- indexArrayM ary i
468+
b <- f a
469+
writeArray mary i b
470+
go (i + 1) mary
471+
in do
472+
mary <- newArray sz badTraverseValue
473+
go 0 mary
474+
{-# INLINE unsafeTraverseArray #-}
451475

452476
#if MIN_VERSION_base(4,7,0)
453477
instance Exts.IsList (Array a) where

Data/Primitive/SmallArray.hs

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ module Data.Primitive.SmallArray
5454
, unsafeThawSmallArray
5555
, sizeofSmallArray
5656
, sizeofSmallMutableArray
57+
, unsafeTraverseSmallArray
5758
) where
5859

5960

@@ -71,9 +72,7 @@ import Control.Monad
7172
import Control.Monad.Fix
7273
import Control.Monad.Primitive
7374
import Control.Monad.ST
74-
#if MIN_VERSION_base(4,4,0)
7575
import Control.Monad.Zip
76-
#endif
7776
import Data.Data
7877
import Data.Foldable
7978
import Data.Functor.Identity
@@ -108,9 +107,7 @@ newtype SmallArray a = SmallArray (Array a) deriving
108107
, Alternative
109108
, Monad
110109
, MonadPlus
111-
#if MIN_VERSION_base(4,4,0)
112110
, MonadZip
113-
#endif
114111
, MonadFix
115112
, Monoid
116113
, Typeable
@@ -390,6 +387,37 @@ sizeofSmallMutableArray (SmallMutableArray ma) = sizeofMutableArray ma
390387
#endif
391388
{-# INLINE sizeofSmallMutableArray #-}
392389

390+
-- | This is the fastest, most straightforward way to traverse
391+
-- an array, but it only works correctly with a sufficiently
392+
-- "affine" 'PrimMonad' instance. In particular, it must only produce
393+
-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed
394+
-- monads, for example, will not work right at all.
395+
unsafeTraverseSmallArray
396+
:: PrimMonad m
397+
=> (a -> m b)
398+
-> SmallArray a
399+
-> m (SmallArray b)
400+
#if HAVE_SMALL_ARRAY
401+
unsafeTraverseSmallArray f = \ !ary ->
402+
let
403+
!sz = sizeofSmallArray ary
404+
go !i !mary
405+
| i == sz
406+
= unsafeFreezeSmallArray mary
407+
| otherwise
408+
= do
409+
a <- indexSmallArrayM ary i
410+
b <- f a
411+
writeSmallArray mary i b
412+
go (i + 1) mary
413+
in do
414+
mary <- newSmallArray sz badTraverseValue
415+
go 0 mary
416+
#else
417+
unsafeTraverseSmallArray f (SmallArray ar) = SmallArray `liftM` unsafeTraverseArray f ar
418+
#endif
419+
{-# INLINE unsafeTraverseSmallArray #-}
420+
393421
#if HAVE_SMALL_ARRAY
394422
die :: String -> String -> a
395423
die fun problem = error $ "Data.Primitive.SmallArray." ++ fun ++ ": " ++ problem
@@ -476,7 +504,6 @@ instance Foldable SmallArray where
476504
then die "foldl1" "Empty SmallArray"
477505
else go sz
478506
{-# INLINE foldl1 #-}
479-
#if MIN_VERSION_base(4,6,0)
480507
foldr' f = \z !ary ->
481508
let
482509
go i !acc
@@ -494,8 +521,6 @@ instance Foldable SmallArray where
494521
= go (i+1) (f acc x)
495522
in go 0 z
496523
{-# INLINE foldl' #-}
497-
#endif
498-
#if MIN_VERSION_base(4,8,0)
499524
null a = sizeofSmallArray a == 0
500525
{-# INLINE null #-}
501526
length = sizeofSmallArray
@@ -523,11 +548,51 @@ instance Foldable SmallArray where
523548
{-# INLINE sum #-}
524549
product = foldl' (*) 1
525550
{-# INLINE product #-}
526-
#endif
551+
552+
newtype STA a = STA {_runSTA :: forall s. SmallMutableArray# s a -> ST s (SmallArray a)}
553+
554+
runSTA :: Int -> STA a -> SmallArray a
555+
runSTA !sz = \ (STA m) -> runST $ newSmallArray_ sz >>=
556+
\ (SmallMutableArray ar#) -> m ar#
557+
{-# INLINE runSTA #-}
558+
559+
newSmallArray_ :: Int -> ST s (SmallMutableArray s a)
560+
newSmallArray_ !n = newSmallArray n badTraverseValue
561+
562+
badTraverseValue :: a
563+
badTraverseValue = die "traverse" "bad indexing"
564+
{-# NOINLINE badTraverseValue #-}
527565

528566
instance Traversable SmallArray where
529-
traverse f sa = fromListN l <$> traverse (f . indexSmallArray sa) [0..l-1]
530-
where l = length sa
567+
traverse f = traverseSmallArray f
568+
{-# INLINE traverse #-}
569+
570+
traverseSmallArray
571+
:: Applicative f
572+
=> (a -> f b) -> SmallArray a -> f (SmallArray b)
573+
traverseSmallArray f = \ !ary ->
574+
let
575+
!len = sizeofSmallArray ary
576+
go !i
577+
| i == len
578+
= pure $ STA $ \mary -> unsafeFreezeSmallArray (SmallMutableArray mary)
579+
| (# x #) <- indexSmallArray## ary i
580+
= liftA2 (\b (STA m) -> STA $ \mary ->
581+
writeSmallArray (SmallMutableArray mary) i b >> m mary)
582+
(f x) (go (i + 1))
583+
in if len == 0
584+
then pure emptySmallArray
585+
else runSTA len <$> go 0
586+
{-# INLINE [1] traverseSmallArray #-}
587+
588+
{-# RULES
589+
"traverse/ST" forall (f :: a -> ST s b). traverseSmallArray f = unsafeTraverseSmallArray f
590+
"traverse/IO" forall (f :: a -> IO b). traverseSmallArray f = unsafeTraverseSmallArray f
591+
"traverse/Id" forall (f :: a -> Identity b). traverseSmallArray f =
592+
(coerce :: (SmallArray a -> SmallArray (Identity b))
593+
-> SmallArray a -> Identity (SmallArray b)) (fmap f)
594+
#-}
595+
531596

532597
instance Functor SmallArray where
533598
fmap f sa = createSmallArray (length sa) (die "fmap" "impossible") $ \smb ->

0 commit comments

Comments
 (0)