Skip to content

Commit ea0734a

Browse files
authored
Merge pull request #96 from treeowl/more-eagerness
More eagerness
2 parents dc89b87 + ff764da commit ea0734a

File tree

2 files changed

+141
-63
lines changed

2 files changed

+141
-63
lines changed

Data/Primitive/Array.hs

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import Data.Primitive.Internal.Compat ( isTrue#, mkNoRepType )
4343
import Control.Monad.ST(ST,runST)
4444

4545
import Control.Applicative
46-
import Control.Monad (MonadPlus(..))
46+
import Control.Monad (MonadPlus(..), when)
4747
import Control.Monad.Fix
4848
#if MIN_VERSION_base(4,4,0)
4949
import Control.Monad.Zip
@@ -291,7 +291,9 @@ die fun problem = error $ "Data.Primitive.Array." ++ fun ++ ": " ++ problem
291291
instance Eq a => Eq (Array a) where
292292
a1 == a2 = sizeofArray a1 == sizeofArray a2 && loop (sizeofArray a1 - 1)
293293
where loop i | i < 0 = True
294-
| otherwise = indexArray a1 i == indexArray a2 i && loop (i-1)
294+
| (# x1 #) <- indexArray## a1 i
295+
, (# x2 #) <- indexArray## a2 i
296+
= x1 == x2 && loop (i-1)
295297

296298
instance Eq (MutableArray s a) where
297299
ma1 == ma2 = isTrue# (sameMutableArray# (marray# ma1) (marray# ma2))
@@ -301,7 +303,10 @@ instance Ord a => Ord (Array a) where
301303
where
302304
mn = sizeofArray a1 `min` sizeofArray a2
303305
loop i
304-
| i < mn = compare (indexArray a1 i) (indexArray a2 i) `mappend` loop (i+1)
306+
| i < mn
307+
, (# x1 #) <- indexArray## a1 i
308+
, (# x2 #) <- indexArray## a2 i
309+
= compare x1 x2 `mappend` loop (i+1)
305310
| otherwise = compare (sizeofArray a1) (sizeofArray a2)
306311

307312
instance Foldable Array where
@@ -498,9 +503,11 @@ fromList l = fromListN (length l) l
498503
instance Functor Array where
499504
fmap f a =
500505
createArray (sizeofArray a) (die "fmap" "impossible") $ \mb ->
501-
let go i | i < sizeofArray a = return ()
502-
| otherwise = writeArray mb i (f $ indexArray a i)
503-
>> go (i+1)
506+
let go i | i == sizeofArray a
507+
= return ()
508+
| otherwise
509+
= do x <- indexArrayM a i
510+
writeArray mb i (f x) >> go (i+1)
504511
in go 0
505512
#if MIN_VERSION_base(4,8,0)
506513
e <$ a = runST $ newArray (sizeofArray a) e >>= unsafeFreezeArray
@@ -510,12 +517,15 @@ instance Applicative Array where
510517
pure x = runST $ newArray 1 x >>= unsafeFreezeArray
511518
ab <*> a = runST $ do
512519
mb <- newArray (szab*sza) $ die "<*>" "impossible"
513-
let go1 i
514-
| i < szab = go2 (i*sza) (indexArray ab i) 0 >> go1 (i+1)
515-
| otherwise = return ()
516-
go2 off f j
517-
| j < sza = writeArray mb (off + j) (f $ indexArray a j)
518-
| otherwise = return ()
520+
let go1 i = when (i < szab) $
521+
do
522+
f <- indexArrayM ab i
523+
go2 (i*sza) f 0
524+
go1 (i+1)
525+
go2 off f j = when (j < sza) $
526+
do
527+
x <- indexArrayM a j
528+
writeArray mb (off + j) (f x)
519529
go1 0
520530
unsafeFreezeArray mb
521531
where szab = sizeofArray ab ; sza = sizeofArray a
@@ -527,7 +537,9 @@ instance Applicative Array where
527537
a <* b = createArray (sza*szb) (die "<*" "impossible") $ \ma ->
528538
let fill off i e | i < szb = writeArray ma (off+i) e >> fill off (i+1) e
529539
| otherwise = return ()
530-
go i | i < sza = fill (i*szb) 0 (indexArray a i) >> go (i+1)
540+
go i | i < sza
541+
= do x <- indexArrayM a i
542+
fill (i*szb) 0 x >> go (i+1)
531543
| otherwise = return ()
532544
in go 0
533545
where sza = sizeofArray a ; szb = sizeofArray b
@@ -542,20 +554,36 @@ instance Alternative Array where
542554
many a | sizeofArray a == 0 = pure []
543555
| otherwise = die "many" "infinite arrays are not well defined"
544556

557+
data ArrayStack a
558+
= PushArray !(Array a) !(ArrayStack a)
559+
| EmptyStack
560+
-- See the note in SmallArray about how we might improve this.
561+
545562
instance Monad Array where
546563
return = pure
547564
(>>) = (*>)
548-
a >>= f = push 0 [] (sizeofArray a - 1)
565+
566+
ary >>= f = collect 0 EmptyStack (la-1)
549567
where
550-
push !sz bs i
551-
| i < 0 = build sz bs
552-
| otherwise = let b = f $ indexArray a i
553-
in push (sz + sizeofArray b) (b:bs) (i+1)
554-
555-
build sz stk = createArray sz (die ">>=" "impossible") $ \mb ->
556-
let go off (b:bs) = copyArray mb off b 0 (sizeofArray b) >> go (off + sizeofArray b) bs
557-
go _ [ ] = return ()
558-
in go 0 stk
568+
la = sizeofArray ary
569+
collect sz stk i
570+
| i < 0 = createArray sz (die ">>=" "impossible") $ fill 0 stk
571+
| (# x #) <- indexArray## ary i
572+
, let sb = f x
573+
lsb = sizeofArray sb
574+
-- If we don't perform this check, we could end up allocating
575+
-- a stack full of empty arrays if someone is filtering most
576+
-- things out. So we refrain from pushing empty arrays.
577+
= if lsb == 0
578+
then collect sz stk (i - 1)
579+
else collect (sz + lsb) (PushArray sb stk) (i-1)
580+
581+
fill _ EmptyStack _ = return ()
582+
fill off (PushArray sb sbs) smb
583+
| let lsb = sizeofArray sb
584+
= copyArray smb off sb 0 (lsb)
585+
*> fill (off + lsb) sbs smb
586+
559587
fail _ = empty
560588

561589
instance MonadPlus Array where
@@ -564,10 +592,13 @@ instance MonadPlus Array where
564592

565593
zipW :: String -> (a -> b -> c) -> Array a -> Array b -> Array c
566594
zipW s f aa ab = createArray mn (die s "impossible") $ \mc ->
567-
let go i
568-
| i < mn = writeArray mc i (f (indexArray aa i) (indexArray ab i))
569-
>> go (i+1)
570-
| otherwise = return ()
595+
let go i | i < mn
596+
= do
597+
x <- indexArrayM aa i
598+
y <- indexArrayM ab i
599+
writeArray mc i (f x y)
600+
go (i+1)
601+
| otherwise = return ()
571602
in go 0
572603
where mn = sizeofArray aa `min` sizeofArray ab
573604
{-# INLINE zipW #-}
@@ -581,7 +612,7 @@ instance MonadZip Array where
581612
ma <- newArray sz (die "munzip" "impossible")
582613
mb <- newArray sz (die "munzip" "impossible")
583614
let go i | i < sz = do
584-
let (a, b) = indexArray aab i
615+
(a, b) <- indexArrayM aab i
585616
writeArray ma i a
586617
writeArray mb i b
587618
go (i+1)
@@ -591,7 +622,14 @@ instance MonadZip Array where
591622
#endif
592623

593624
instance MonadFix Array where
594-
mfix f = let l = mfix (toList . f) in fromListN (length l) l
625+
mfix f = createArray (sizeofArray (f err))
626+
(die "mfix" "impossible") $ flip fix 0 $
627+
\r !i !mary -> when (i < sz) $ do
628+
writeArray mary i (fix (\xi -> f xi `indexArray` i))
629+
r (i + 1) mary
630+
where
631+
sz = sizeofArray (f err)
632+
err = error "mfix for Data.Primitive.Array applied to strict function."
595633

596634
#if MIN_VERSION_base(4,9,0)
597635
instance Semigroup (Array a) where

Data/Primitive/SmallArray.hs

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ import Control.Monad.Primitive
7474
import Control.Monad.ST
7575
import Control.Monad.Zip
7676
import Data.Data
77-
import Data.Foldable
77+
import Data.Foldable as Foldable
7878
import Data.Functor.Identity
7979
#if !(MIN_VERSION_base(4,11,0))
8080
import Data.Monoid
@@ -118,7 +118,7 @@ instance IsList (SmallArray a) where
118118
type Item (SmallArray a) = a
119119
fromListN n l = SmallArray (fromListN n l)
120120
fromList l = SmallArray (fromList l)
121-
toList (SmallArray a) = toList a
121+
toList a = Foldable.toList a
122122
#endif
123123
#endif
124124

@@ -447,19 +447,27 @@ instance Eq a => Eq (SmallArray a) where
447447
sa1 == sa2 = length sa1 == length sa2 && loop (length sa1 - 1)
448448
where
449449
loop i
450-
| i < 0 = True
451-
| otherwise = indexSmallArray sa1 i == indexSmallArray sa2 i && loop (i-1)
450+
| i < 0
451+
= True
452+
| (# x #) <- indexSmallArray## sa1 i
453+
, (# y #) <- indexSmallArray## sa2 i
454+
= x == y && loop (i-1)
452455

453456
instance Eq (SmallMutableArray s a) where
454457
SmallMutableArray sma1# == SmallMutableArray sma2# =
455458
isTrue# (sameSmallMutableArray# sma1# sma2#)
456459

457460
instance Ord a => Ord (SmallArray a) where
458-
compare sl sr = fix ? 0 $ \go i ->
459-
if i < l
460-
then compare (indexSmallArray sl i) (indexSmallArray sr i) <> go (i+1)
461-
else compare (length sl) (length sr)
462-
where l = length sl `min` length sr
461+
compare a1 a2 = loop 0
462+
where
463+
mn = length a1 `min` length a2
464+
loop i
465+
| i < mn
466+
, (# x1 #) <- indexSmallArray## a1 i
467+
, (# x2 #) <- indexSmallArray## a2 i
468+
= compare x1 x2 `mappend` loop (i+1)
469+
| otherwise = compare (length a1) (length a2)
470+
463471

464472
instance Foldable SmallArray where
465473
-- Note: we perform the array lookups eagerly so we won't
@@ -597,8 +605,9 @@ traverseSmallArray f = \ !ary ->
597605
instance Functor SmallArray where
598606
fmap f sa = createSmallArray (length sa) (die "fmap" "impossible") $ \smb ->
599607
fix ? 0 $ \go i ->
600-
when (i < length sa) $
601-
writeSmallArray smb i (f $ indexSmallArray sa i) *> go (i+1)
608+
when (i < length sa) $ do
609+
x <- indexSmallArrayM sa i
610+
writeSmallArray smb i (f x) *> go (i+1)
602611
{-# INLINE fmap #-}
603612

604613
x <$ sa = createSmallArray (length sa) x noOp
@@ -613,22 +622,23 @@ instance Applicative SmallArray where
613622
where
614623
la = length sa ; lb = length sb
615624

616-
sa <* sb = createSmallArray (la*lb) (indexSmallArray sa $ la-1) $ \sma ->
617-
fix ? 0 $ \outer i -> when (i < la-1) $ do
618-
let a = indexSmallArray sa i
619-
fix ? 0 $ \inner j ->
620-
when (j < lb) $
621-
writeSmallArray sma (la*i + j) a *> inner (j+1)
622-
outer $ i+1
623-
where
624-
la = length sa ; lb = length sb
625+
a <* b = createSmallArray (sza*szb) (die "<*" "impossible") $ \ma ->
626+
let fill off i e = when (i < szb) $
627+
writeSmallArray ma (off+i) e >> fill off (i+1) e
628+
go i = when (i < sza) $ do
629+
x <- indexSmallArrayM a i
630+
fill (i*szb) 0 x
631+
go (i+1)
632+
in go 0
633+
where sza = sizeofSmallArray a ; szb = sizeofSmallArray b
625634

626635
sf <*> sx = createSmallArray (lf*lx) (die "<*>" "impossible") $ \smb ->
627636
fix ? 0 $ \outer i -> when (i < lf) $ do
628-
let f = indexSmallArray sf i
637+
f <- indexSmallArrayM sf i
629638
fix ? 0 $ \inner j ->
630-
when (j < lx) $
631-
writeSmallArray smb (lf*i + j) (f $ indexSmallArray sx j)
639+
when (j < lx) $ do
640+
x <- indexSmallArrayM sx j
641+
writeSmallArray smb (lf*i + j) (f x)
632642
*> inner (j+1)
633643
outer $ i+1
634644
where
@@ -648,20 +658,41 @@ instance Alternative SmallArray where
648658
some sa | null sa = emptySmallArray
649659
| otherwise = die "some" "infinite arrays are not well defined"
650660

661+
data ArrayStack a
662+
= PushArray !(SmallArray a) !(ArrayStack a)
663+
| EmptyStack
664+
-- TODO: This isn't terribly efficient. It would be better to wrap
665+
-- ArrayStack with a type like
666+
--
667+
-- data NES s a = NES !Int !(SmallMutableArray s a) !(ArrayStack a)
668+
--
669+
-- We'd copy incoming arrays into the mutable array until we would
670+
-- overflow it. Then we'd freeze it, push it on the stack, and continue.
671+
-- Any sufficiently large incoming arrays would go straight on the stack.
672+
-- Such a scheme would make the stack much more compact in the case
673+
-- of many small arrays.
674+
651675
instance Monad SmallArray where
652676
return = pure
653677
(>>) = (*>)
654678

655-
sa >>= f = collect 0 [] (la-1)
679+
sa >>= f = collect 0 EmptyStack (la-1)
656680
where
657681
la = length sa
658682
collect sz stk i
659683
| i < 0 = createSmallArray sz (die ">>=" "impossible") $ fill 0 stk
660-
| otherwise = let sb = f $ indexSmallArray sa i in
661-
collect (sz + length sb) (sb:stk) (i-1)
662-
663-
fill _ [ ] _ = return ()
664-
fill off (sb:sbs) smb =
684+
| (# x #) <- indexSmallArray## sa i
685+
, let sb = f x
686+
lsb = length sb
687+
-- If we don't perform this check, we could end up allocating
688+
-- a stack full of empty arrays if someone is filtering most
689+
-- things out. So we refrain from pushing empty arrays.
690+
= if lsb == 0
691+
then collect sz stk (i-1)
692+
else collect (sz + lsb) (PushArray sb stk) (i-1)
693+
694+
fill _ EmptyStack _ = return ()
695+
fill off (PushArray sb sbs) smb =
665696
copySmallArray smb off sb 0 (length sb)
666697
*> fill (off + length sb) sbs smb
667698

@@ -674,9 +705,11 @@ instance MonadPlus SmallArray where
674705
zipW :: String -> (a -> b -> c) -> SmallArray a -> SmallArray b -> SmallArray c
675706
zipW nm = \f sa sb -> let mn = length sa `min` length sb in
676707
createSmallArray mn (die nm "impossible") $ \mc ->
677-
fix ? 0 $ \go i -> when (i < mn) $
678-
writeSmallArray mc i (f (indexSmallArray sa i) (indexSmallArray sb i))
679-
*> go (i+1)
708+
fix ? 0 $ \go i -> when (i < mn) $ do
709+
x <- indexSmallArrayM sa i
710+
y <- indexSmallArrayM sb i
711+
writeSmallArray mc i (f x y)
712+
go (i+1)
680713
{-# INLINE zipW #-}
681714

682715
instance MonadZip SmallArray where
@@ -696,7 +729,14 @@ instance MonadZip SmallArray where
696729
<*> unsafeFreezeSmallArray smb
697730

698731
instance MonadFix SmallArray where
699-
mfix f = fromList . mfix $ toList . f
732+
mfix f = createSmallArray (sizeofSmallArray (f err))
733+
(die "mfix" "impossible") $ flip fix 0 $
734+
\r !i !mary -> when (i < sz) $ do
735+
writeSmallArray mary i (fix (\xi -> f xi `indexSmallArray` i))
736+
r (i + 1) mary
737+
where
738+
sz = sizeofSmallArray (f err)
739+
err = error "mfix for Data.Primitive.SmallArray applied to strict function."
700740

701741
#if MIN_VERSION_base(4,9,0)
702742
instance Sem.Semigroup (SmallArray a) where
@@ -723,7 +763,7 @@ instance IsList (SmallArray a) where
723763
[] -> pure ()
724764
x:xs -> writeSmallArray sma i x *> go (i+1) xs
725765
fromList l = fromListN (length l) l
726-
toList sa = indexSmallArray sa <$> [0 .. length sa - 1]
766+
toList = Foldable.toList
727767

728768
instance Show a => Show (SmallArray a) where
729769
showsPrec p sa = showParen (p > 10) $

0 commit comments

Comments
 (0)