Skip to content

Commit ff764da

Browse files
committed
Make some things more eager
* Perform array indexing eagerly in general to avoid useless thunks. * Make `munzip` stricter for `Array`, to match the `SmallArray` instance and avoid loads of thunks. * Give `Array` and `SmallArray` much less inefficient `MonadFix` instances. Leaning on the instance for `[]` is bad because indexing into lists is expensive, and that's effectively what the `MonadFix` instance does.
1 parent bb0099f commit ff764da

File tree

2 files changed

+136
-61
lines changed

2 files changed

+136
-61
lines changed

Data/Primitive/Array.hs

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ import Data.Primitive.Internal.Compat ( isTrue#, mkNoRepType )
4242
import Control.Monad.ST(ST,runST)
4343

4444
import Control.Applicative
45-
import Control.Monad (MonadPlus(..))
45+
import Control.Monad (MonadPlus(..), when)
4646
import Control.Monad.Fix
4747
#if MIN_VERSION_base(4,4,0)
4848
import Control.Monad.Zip
@@ -339,7 +339,9 @@ die fun problem = error $ "Data.Primitive.Array." ++ fun ++ ": " ++ problem
339339
instance Eq a => Eq (Array a) where
340340
a1 == a2 = sizeofArray a1 == sizeofArray a2 && loop (sizeofArray a1 - 1)
341341
where loop i | i < 0 = True
342-
| otherwise = indexArray a1 i == indexArray a2 i && loop (i-1)
342+
| (# x1 #) <- indexArray## a1 i
343+
, (# x2 #) <- indexArray## a2 i
344+
= x1 == x2 && loop (i-1)
343345

344346
instance Eq (MutableArray s a) where
345347
ma1 == ma2 = isTrue# (sameMutableArray# (marray# ma1) (marray# ma2))
@@ -349,7 +351,10 @@ instance Ord a => Ord (Array a) where
349351
where
350352
mn = sizeofArray a1 `min` sizeofArray a2
351353
loop i
352-
| i < mn = compare (indexArray a1 i) (indexArray a2 i) `mappend` loop (i+1)
354+
| i < mn
355+
, (# x1 #) <- indexArray## a1 i
356+
, (# x2 #) <- indexArray## a2 i
357+
= compare x1 x2 `mappend` loop (i+1)
353358
| otherwise = compare (sizeofArray a1) (sizeofArray a2)
354359

355360
instance Foldable Array where
@@ -474,9 +479,11 @@ fromList l = fromListN (length l) l
474479
instance Functor Array where
475480
fmap f a =
476481
createArray (sizeofArray a) (die "fmap" "impossible") $ \mb ->
477-
let go i | i == sizeofArray a = return ()
478-
| otherwise = writeArray mb i (f $ indexArray a i)
479-
>> go (i+1)
482+
let go i | i == sizeofArray a
483+
= return ()
484+
| otherwise
485+
= do x <- indexArrayM a i
486+
writeArray mb i (f x) >> go (i+1)
480487
in go 0
481488
#if MIN_VERSION_base(4,8,0)
482489
e <$ a = runST $ newArray (sizeofArray a) e >>= unsafeFreezeArray
@@ -486,12 +493,15 @@ instance Applicative Array where
486493
pure x = runST $ newArray 1 x >>= unsafeFreezeArray
487494
ab <*> a = runST $ do
488495
mb <- newArray (szab*sza) $ die "<*>" "impossible"
489-
let go1 i
490-
| i < szab = go2 (i*sza) (indexArray ab i) 0 >> go1 (i+1)
491-
| otherwise = return ()
492-
go2 off f j
493-
| j < sza = writeArray mb (off + j) (f $ indexArray a j)
494-
| otherwise = return ()
496+
let go1 i = when (i < szab) $
497+
do
498+
f <- indexArrayM ab i
499+
go2 (i*sza) f 0
500+
go1 (i+1)
501+
go2 off f j = when (j < sza) $
502+
do
503+
x <- indexArrayM a j
504+
writeArray mb (off + j) (f x)
495505
go1 0
496506
unsafeFreezeArray mb
497507
where szab = sizeofArray ab ; sza = sizeofArray a
@@ -503,7 +513,9 @@ instance Applicative Array where
503513
a <* b = createArray (sza*szb) (die "<*" "impossible") $ \ma ->
504514
let fill off i e | i < szb = writeArray ma (off+i) e >> fill off (i+1) e
505515
| otherwise = return ()
506-
go i | i < sza = fill (i*szb) 0 (indexArray a i) >> go (i+1)
516+
go i | i < sza
517+
= do x <- indexArrayM a i
518+
fill (i*szb) 0 x >> go (i+1)
507519
| otherwise = return ()
508520
in go 0
509521
where sza = sizeofArray a ; szb = sizeofArray b
@@ -518,22 +530,35 @@ instance Alternative Array where
518530
many a | sizeofArray a == 0 = pure []
519531
| otherwise = die "many" "infinite arrays are not well defined"
520532

533+
data ArrayStack a
534+
= PushArray !(Array a) !(ArrayStack a)
535+
| EmptyStack
536+
-- See the note in SmallArray about how we might improve this.
537+
521538
instance Monad Array where
522539
return = pure
523540
(>>) = (*>)
524541

525-
ary >>= f = collect 0 [] (la-1)
542+
ary >>= f = collect 0 EmptyStack (la-1)
526543
where
527544
la = sizeofArray ary
528545
collect sz stk i
529546
| i < 0 = createArray sz (die ">>=" "impossible") $ fill 0 stk
530-
| otherwise = let sb = f $ indexArray ary i in
531-
collect (sz + sizeofArray sb) (sb:stk) (i-1)
532-
533-
fill _ [ ] _ = return ()
534-
fill off (sb:sbs) smb =
535-
copyArray smb off sb 0 (sizeofArray sb)
536-
*> fill (off + sizeofArray sb) sbs smb
547+
| (# x #) <- indexArray## ary i
548+
, let sb = f x
549+
lsb = sizeofArray sb
550+
-- If we don't perform this check, we could end up allocating
551+
-- a stack full of empty arrays if someone is filtering most
552+
-- things out. So we refrain from pushing empty arrays.
553+
= if lsb == 0
554+
then collect sz stk (i - 1)
555+
else collect (sz + lsb) (PushArray sb stk) (i-1)
556+
557+
fill _ EmptyStack _ = return ()
558+
fill off (PushArray sb sbs) smb
559+
| let lsb = sizeofArray sb
560+
= copyArray smb off sb 0 (lsb)
561+
*> fill (off + lsb) sbs smb
537562

538563
fail _ = empty
539564

@@ -543,10 +568,13 @@ instance MonadPlus Array where
543568

544569
zipW :: String -> (a -> b -> c) -> Array a -> Array b -> Array c
545570
zipW s f aa ab = createArray mn (die s "impossible") $ \mc ->
546-
let go i
547-
| i < mn = writeArray mc i (f (indexArray aa i) (indexArray ab i))
548-
>> go (i+1)
549-
| otherwise = return ()
571+
let go i | i < mn
572+
= do
573+
x <- indexArrayM aa i
574+
y <- indexArrayM ab i
575+
writeArray mc i (f x y)
576+
go (i+1)
577+
| otherwise = return ()
550578
in go 0
551579
where mn = sizeofArray aa `min` sizeofArray ab
552580
{-# INLINE zipW #-}
@@ -560,7 +588,7 @@ instance MonadZip Array where
560588
ma <- newArray sz (die "munzip" "impossible")
561589
mb <- newArray sz (die "munzip" "impossible")
562590
let go i | i < sz = do
563-
let (a, b) = indexArray aab i
591+
(a, b) <- indexArrayM aab i
564592
writeArray ma i a
565593
writeArray mb i b
566594
go (i+1)
@@ -570,7 +598,14 @@ instance MonadZip Array where
570598
#endif
571599

572600
instance MonadFix Array where
573-
mfix f = let l = mfix (toList . f) in fromListN (length l) l
601+
mfix f = createArray (sizeofArray (f err))
602+
(die "mfix" "impossible") $ flip fix 0 $
603+
\r !i !mary -> when (i < sz) $ do
604+
writeArray mary i (fix (\xi -> f xi `indexArray` i))
605+
r (i + 1) mary
606+
where
607+
sz = sizeofArray (f err)
608+
err = error "mfix for Data.Primitive.Array applied to strict function."
574609

575610
#if MIN_VERSION_base(4,9,0)
576611
instance Semigroup (Array a) where

Data/Primitive/SmallArray.hs

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ import Control.Monad.ST
7575
import Control.Monad.Zip
7676
#endif
7777
import Data.Data
78-
import Data.Foldable
78+
import Data.Foldable as Foldable
7979
import Data.Functor.Identity
8080
#if !(MIN_VERSION_base(4,11,0))
8181
import Data.Monoid
@@ -121,7 +121,7 @@ instance IsList (SmallArray a) where
121121
type Item (SmallArray a) = a
122122
fromListN n l = SmallArray (fromListN n l)
123123
fromList l = SmallArray (fromList l)
124-
toList (SmallArray a) = toList a
124+
toList a = Foldable.toList a
125125
#endif
126126
#endif
127127

@@ -419,19 +419,27 @@ instance Eq a => Eq (SmallArray a) where
419419
sa1 == sa2 = length sa1 == length sa2 && loop (length sa1 - 1)
420420
where
421421
loop i
422-
| i < 0 = True
423-
| otherwise = indexSmallArray sa1 i == indexSmallArray sa2 i && loop (i-1)
422+
| i < 0
423+
= True
424+
| (# x #) <- indexSmallArray## sa1 i
425+
, (# y #) <- indexSmallArray## sa2 i
426+
= x == y && loop (i-1)
424427

425428
instance Eq (SmallMutableArray s a) where
426429
SmallMutableArray sma1# == SmallMutableArray sma2# =
427430
isTrue# (sameSmallMutableArray# sma1# sma2#)
428431

429432
instance Ord a => Ord (SmallArray a) where
430-
compare sl sr = fix ? 0 $ \go i ->
431-
if i < l
432-
then compare (indexSmallArray sl i) (indexSmallArray sr i) <> go (i+1)
433-
else compare (length sl) (length sr)
434-
where l = length sl `min` length sr
433+
compare a1 a2 = loop 0
434+
where
435+
mn = length a1 `min` length a2
436+
loop i
437+
| i < mn
438+
, (# x1 #) <- indexSmallArray## a1 i
439+
, (# x2 #) <- indexSmallArray## a2 i
440+
= compare x1 x2 `mappend` loop (i+1)
441+
| otherwise = compare (length a1) (length a2)
442+
435443

436444
instance Foldable SmallArray where
437445
-- Note: we perform the array lookups eagerly so we won't
@@ -532,8 +540,9 @@ instance Traversable SmallArray where
532540
instance Functor SmallArray where
533541
fmap f sa = createSmallArray (length sa) (die "fmap" "impossible") $ \smb ->
534542
fix ? 0 $ \go i ->
535-
when (i < length sa) $
536-
writeSmallArray smb i (f $ indexSmallArray sa i) *> go (i+1)
543+
when (i < length sa) $ do
544+
x <- indexSmallArrayM sa i
545+
writeSmallArray smb i (f x) *> go (i+1)
537546
{-# INLINE fmap #-}
538547

539548
x <$ sa = createSmallArray (length sa) x noOp
@@ -548,22 +557,23 @@ instance Applicative SmallArray where
548557
where
549558
la = length sa ; lb = length sb
550559

551-
sa <* sb = createSmallArray (la*lb) (indexSmallArray sa $ la-1) $ \sma ->
552-
fix ? 0 $ \outer i -> when (i < la-1) $ do
553-
let a = indexSmallArray sa i
554-
fix ? 0 $ \inner j ->
555-
when (j < lb) $
556-
writeSmallArray sma (la*i + j) a *> inner (j+1)
557-
outer $ i+1
558-
where
559-
la = length sa ; lb = length sb
560+
a <* b = createSmallArray (sza*szb) (die "<*" "impossible") $ \ma ->
561+
let fill off i e = when (i < szb) $
562+
writeSmallArray ma (off+i) e >> fill off (i+1) e
563+
go i = when (i < sza) $ do
564+
x <- indexSmallArrayM a i
565+
fill (i*szb) 0 x
566+
go (i+1)
567+
in go 0
568+
where sza = sizeofSmallArray a ; szb = sizeofSmallArray b
560569

561570
sf <*> sx = createSmallArray (lf*lx) (die "<*>" "impossible") $ \smb ->
562571
fix ? 0 $ \outer i -> when (i < lf) $ do
563-
let f = indexSmallArray sf i
572+
f <- indexSmallArrayM sf i
564573
fix ? 0 $ \inner j ->
565-
when (j < lx) $
566-
writeSmallArray smb (lf*i + j) (f $ indexSmallArray sx j)
574+
when (j < lx) $ do
575+
x <- indexSmallArrayM sx j
576+
writeSmallArray smb (lf*i + j) (f x)
567577
*> inner (j+1)
568578
outer $ i+1
569579
where
@@ -583,20 +593,41 @@ instance Alternative SmallArray where
583593
some sa | null sa = emptySmallArray
584594
| otherwise = die "some" "infinite arrays are not well defined"
585595

596+
data ArrayStack a
597+
= PushArray !(SmallArray a) !(ArrayStack a)
598+
| EmptyStack
599+
-- TODO: This isn't terribly efficient. It would be better to wrap
600+
-- ArrayStack with a type like
601+
--
602+
-- data NES s a = NES !Int !(SmallMutableArray s a) !(ArrayStack a)
603+
--
604+
-- We'd copy incoming arrays into the mutable array until we would
605+
-- overflow it. Then we'd freeze it, push it on the stack, and continue.
606+
-- Any sufficiently large incoming arrays would go straight on the stack.
607+
-- Such a scheme would make the stack much more compact in the case
608+
-- of many small arrays.
609+
586610
instance Monad SmallArray where
587611
return = pure
588612
(>>) = (*>)
589613

590-
sa >>= f = collect 0 [] (la-1)
614+
sa >>= f = collect 0 EmptyStack (la-1)
591615
where
592616
la = length sa
593617
collect sz stk i
594618
| i < 0 = createSmallArray sz (die ">>=" "impossible") $ fill 0 stk
595-
| otherwise = let sb = f $ indexSmallArray sa i in
596-
collect (sz + length sb) (sb:stk) (i-1)
597-
598-
fill _ [ ] _ = return ()
599-
fill off (sb:sbs) smb =
619+
| (# x #) <- indexSmallArray## sa i
620+
, let sb = f x
621+
lsb = length sb
622+
-- If we don't perform this check, we could end up allocating
623+
-- a stack full of empty arrays if someone is filtering most
624+
-- things out. So we refrain from pushing empty arrays.
625+
= if lsb == 0
626+
then collect sz stk (i-1)
627+
else collect (sz + lsb) (PushArray sb stk) (i-1)
628+
629+
fill _ EmptyStack _ = return ()
630+
fill off (PushArray sb sbs) smb =
600631
copySmallArray smb off sb 0 (length sb)
601632
*> fill (off + length sb) sbs smb
602633

@@ -609,9 +640,11 @@ instance MonadPlus SmallArray where
609640
zipW :: String -> (a -> b -> c) -> SmallArray a -> SmallArray b -> SmallArray c
610641
zipW nm = \f sa sb -> let mn = length sa `min` length sb in
611642
createSmallArray mn (die nm "impossible") $ \mc ->
612-
fix ? 0 $ \go i -> when (i < mn) $
613-
writeSmallArray mc i (f (indexSmallArray sa i) (indexSmallArray sb i))
614-
*> go (i+1)
643+
fix ? 0 $ \go i -> when (i < mn) $ do
644+
x <- indexSmallArrayM sa i
645+
y <- indexSmallArrayM sb i
646+
writeSmallArray mc i (f x y)
647+
go (i+1)
615648
{-# INLINE zipW #-}
616649

617650
instance MonadZip SmallArray where
@@ -631,7 +664,14 @@ instance MonadZip SmallArray where
631664
<*> unsafeFreezeSmallArray smb
632665

633666
instance MonadFix SmallArray where
634-
mfix f = fromList . mfix $ toList . f
667+
mfix f = createSmallArray (sizeofSmallArray (f err))
668+
(die "mfix" "impossible") $ flip fix 0 $
669+
\r !i !mary -> when (i < sz) $ do
670+
writeSmallArray mary i (fix (\xi -> f xi `indexSmallArray` i))
671+
r (i + 1) mary
672+
where
673+
sz = sizeofSmallArray (f err)
674+
err = error "mfix for Data.Primitive.SmallArray applied to strict function."
635675

636676
#if MIN_VERSION_base(4,9,0)
637677
instance Sem.Semigroup (SmallArray a) where
@@ -658,7 +698,7 @@ instance IsList (SmallArray a) where
658698
[] -> pure ()
659699
x:xs -> writeSmallArray sma i x *> go (i+1) xs
660700
fromList l = fromListN (length l) l
661-
toList sa = indexSmallArray sa <$> [0 .. length sa - 1]
701+
toList = Foldable.toList
662702

663703
instance Show a => Show (SmallArray a) where
664704
showsPrec p sa = showParen (p > 10) $

0 commit comments

Comments
 (0)