Skip to content

Commit 90e2e59

Browse files
committed
Add special traversals
* Add `unsafeTraverseArray` and `unsafeTraverseSmallArray` functions. * Add rewrite rules to use them for traversals in `ST s` and `IO`. * Add rewrite rules for traversing in `Identity`.
1 parent be7619f commit 90e2e59

File tree

2 files changed

+126
-34
lines changed

2 files changed

+126
-34
lines changed

Data/Primitive/Array.hs

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ module Data.Primitive.Array (
2222
copyArray, copyMutableArray,
2323
cloneArray, cloneMutableArray,
2424
sizeofArray, sizeofMutableArray,
25-
fromListN, fromList
25+
fromListN, fromList,
26+
unsafeTraverseArray
2627
) where
2728

2829
import Control.Monad.Primitive
@@ -56,6 +57,9 @@ import Data.Monoid
5657
import qualified Data.Foldable as F
5758
import Data.Semigroup
5859
#endif
60+
#if MIN_VERSION_base(4,8,0)
61+
import Data.Functor.Identity
62+
#endif
5963

6064
import Text.ParserCombinators.ReadP
6165

@@ -406,20 +410,69 @@ badTraverseValue = die "traverse" "bad indexing"
406410
{-# NOINLINE badTraverseValue #-}
407411

408412
instance Traversable Array where
409-
traverse f = \ !ary ->
410-
let
411-
!len = sizeofArray ary
412-
go !i
413-
| i == len = pure $ STA $ \mary -> unsafeFreezeArray (MutableArray mary)
414-
| (# x #) <- indexArray## ary i
415-
= liftA2 (\b (STA m) -> STA $ \mary ->
416-
writeArray (MutableArray mary) i b >> m mary)
417-
(f x) (go (i + 1))
418-
in if len == 0
419-
then pure emptyArray
420-
else runSTA len <$> go 0
413+
traverse f = traverseArray f
421414
{-# INLINE traverse #-}
422415

416+
traverseArray
417+
:: Applicative f
418+
=> (a -> f b)
419+
-> Array a
420+
-> f (Array b)
421+
traverseArray f = \ !ary ->
422+
let
423+
!len = sizeofArray ary
424+
go !i
425+
| i == len = pure $ STA $ \mary -> unsafeFreezeArray (MutableArray mary)
426+
| (# x #) <- indexArray## ary i
427+
= liftA2 (\b (STA m) -> STA $ \mary ->
428+
writeArray (MutableArray mary) i b >> m mary)
429+
(f x) (go (i + 1))
430+
in if len == 0
431+
then pure emptyArray
432+
else runSTA len <$> go 0
433+
{-# INLINE [1] traverseArray #-}
434+
435+
{-# RULES
436+
"traverse/ST" forall (f :: a -> ST s b). traverseArray f =
437+
unsafeTraverseArray f
438+
"traverse/IO" forall (f :: a -> IO b). traverseArray f =
439+
unsafeTraverseArray f
440+
#-}
441+
#if MIN_VERSION_base(4,8,0)
442+
{-# RULES
443+
"traverse/Id" forall (f :: a -> Identity b). traverseArray f =
444+
(coerce :: (Array a -> Array (Identity b))
445+
-> Array a -> Identity (Array b)) (fmap f)
446+
#-}
447+
#endif
448+
449+
-- | This is the fastest, most straightforward way to traverse
450+
-- an array, but it only works correctly with a sufficiently
451+
-- "affine" 'PrimMonad' instance. In particular, it must only produce
452+
-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed
453+
-- monads, for example, will not work right at all.
454+
unsafeTraverseArray
455+
:: PrimMonad m
456+
=> (a -> m b)
457+
-> Array a
458+
-> m (Array b)
459+
unsafeTraverseArray f = \ !ary ->
460+
let
461+
!sz = sizeofArray ary
462+
go !i !mary
463+
| i == sz
464+
= unsafeFreezeArray mary
465+
| otherwise
466+
= do
467+
a <- indexArrayM ary i
468+
b <- f a
469+
writeArray mary i b
470+
go (i + 1) mary
471+
in do
472+
mary <- newArray sz badTraverseValue
473+
go 0 mary
474+
{-# INLINE unsafeTraverseArray #-}
475+
423476
#if MIN_VERSION_base(4,7,0)
424477
instance Exts.IsList (Array a) where
425478
type Item (Array a) = a

Data/Primitive/SmallArray.hs

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ module Data.Primitive.SmallArray
5454
, unsafeThawSmallArray
5555
, sizeofSmallArray
5656
, sizeofSmallMutableArray
57+
, unsafeTraverseSmallArray
5758
) where
5859

5960

@@ -71,9 +72,7 @@ import Control.Monad
7172
import Control.Monad.Fix
7273
import Control.Monad.Primitive
7374
import Control.Monad.ST
74-
#if MIN_VERSION_base(4,4,0)
7575
import Control.Monad.Zip
76-
#endif
7776
import Data.Data
7877
import Data.Foldable
7978
import Data.Functor.Identity
@@ -108,9 +107,7 @@ newtype SmallArray a = SmallArray (Array a) deriving
108107
, Alternative
109108
, Monad
110109
, MonadPlus
111-
#if MIN_VERSION_base(4,4,0)
112110
, MonadZip
113-
#endif
114111
, MonadFix
115112
, Monoid
116113
, Typeable
@@ -390,6 +387,37 @@ sizeofSmallMutableArray (SmallMutableArray ma) = sizeofMutableArray ma
390387
#endif
391388
{-# INLINE sizeofSmallMutableArray #-}
392389

390+
-- | This is the fastest, most straightforward way to traverse
391+
-- an array, but it only works correctly with a sufficiently
392+
-- "affine" 'PrimMonad' instance. In particular, it must only produce
393+
-- *one* result array. 'Control.Monad.Trans.List.ListT'-transformed
394+
-- monads, for example, will not work right at all.
395+
unsafeTraverseSmallArray
396+
:: PrimMonad m
397+
=> (a -> m b)
398+
-> SmallArray a
399+
-> m (SmallArray b)
400+
#if HAVE_SMALL_ARRAY
401+
unsafeTraverseSmallArray f = \ !ary ->
402+
let
403+
!sz = sizeofSmallArray ary
404+
go !i !mary
405+
| i == sz
406+
= unsafeFreezeSmallArray mary
407+
| otherwise
408+
= do
409+
a <- indexSmallArrayM ary i
410+
b <- f a
411+
writeSmallArray mary i b
412+
go (i + 1) mary
413+
in do
414+
mary <- newSmallArray sz badTraverseValue
415+
go 0 mary
416+
#else
417+
unsafeTraverseSmallArray f (SmallArray ar) = SmallArray `liftM` unsafeTraverseArray f ar
418+
#endif
419+
{-# INLINE unsafeTraverseSmallArray #-}
420+
393421
#if HAVE_SMALL_ARRAY
394422
die :: String -> String -> a
395423
die fun problem = error $ "Data.Primitive.SmallArray." ++ fun ++ ": " ++ problem
@@ -476,7 +504,6 @@ instance Foldable SmallArray where
476504
then die "foldl1" "Empty SmallArray"
477505
else go sz
478506
{-# INLINE foldl1 #-}
479-
#if MIN_VERSION_base(4,6,0)
480507
foldr' f = \z !ary ->
481508
let
482509
go i !acc
@@ -494,8 +521,6 @@ instance Foldable SmallArray where
494521
= go (i+1) (f acc x)
495522
in go 0 z
496523
{-# INLINE foldl' #-}
497-
#endif
498-
#if MIN_VERSION_base(4,8,0)
499524
null a = sizeofSmallArray a == 0
500525
{-# INLINE null #-}
501526
length = sizeofSmallArray
@@ -523,7 +548,6 @@ instance Foldable SmallArray where
523548
{-# INLINE sum #-}
524549
product = foldl' (*) 1
525550
{-# INLINE product #-}
526-
#endif
527551

528552
newtype STA a = STA {_runSTA :: forall s. SmallMutableArray# s a -> ST s (SmallArray a)}
529553

@@ -540,21 +564,36 @@ badTraverseValue = die "traverse" "bad indexing"
540564
{-# NOINLINE badTraverseValue #-}
541565

542566
instance Traversable SmallArray where
543-
traverse f = \ !ary ->
544-
let
545-
!len = sizeofSmallArray ary
546-
go !i
547-
| i == len
548-
= pure $ STA $ \mary -> unsafeFreezeSmallArray (SmallMutableArray mary)
549-
| (# x #) <- indexSmallArray## ary i
550-
= liftA2 (\b (STA m) -> STA $ \mary ->
551-
writeSmallArray (SmallMutableArray mary) i b >> m mary)
552-
(f x) (go (i + 1))
553-
in if len == 0
554-
then pure emptySmallArray
555-
else runSTA len <$> go 0
567+
traverse f = traverseSmallArray f
556568
{-# INLINE traverse #-}
557569

570+
traverseSmallArray
571+
:: Applicative f
572+
=> (a -> f b) -> SmallArray a -> f (SmallArray b)
573+
traverseSmallArray f = \ !ary ->
574+
let
575+
!len = sizeofSmallArray ary
576+
go !i
577+
| i == len
578+
= pure $ STA $ \mary -> unsafeFreezeSmallArray (SmallMutableArray mary)
579+
| (# x #) <- indexSmallArray## ary i
580+
= liftA2 (\b (STA m) -> STA $ \mary ->
581+
writeSmallArray (SmallMutableArray mary) i b >> m mary)
582+
(f x) (go (i + 1))
583+
in if len == 0
584+
then pure emptySmallArray
585+
else runSTA len <$> go 0
586+
{-# INLINE [1] traverseSmallArray #-}
587+
588+
{-# RULES
589+
"traverse/ST" forall (f :: a -> ST s b). traverseSmallArray f = unsafeTraverseSmallArray f
590+
"traverse/IO" forall (f :: a -> IO b). traverseSmallArray f = unsafeTraverseSmallArray f
591+
"traverse/Id" forall (f :: a -> Identity b). traverseSmallArray f =
592+
(coerce :: (SmallArray a -> SmallArray (Identity b))
593+
-> SmallArray a -> Identity (SmallArray b)) (fmap f)
594+
#-}
595+
596+
558597
instance Functor SmallArray where
559598
fmap f sa = createSmallArray (length sa) (die "fmap" "impossible") $ \smb ->
560599
fix ? 0 $ \go i ->

0 commit comments

Comments
 (0)