Skip to content

Commit 0824b35

Browse files
committed
update API for summation algorithms
1 parent 424e27f commit 0824b35

File tree

2 files changed

+157
-73
lines changed

2 files changed

+157
-73
lines changed

source/mir/math/sum.d

Lines changed: 156 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ unittest
1616
import mir.ndslice.topology: map;
1717
auto ar = [1, 1e100, 1, -1e100].sliced.map!"a * 10_000";
1818
const r = 20_000;
19-
assert(r == ar.sum!(Summation.kbn));
20-
assert(r == ar.sum!(Summation.kb2));
21-
assert(r == ar.sum!(Summation.precise));
19+
assert(r == ar.sum!"kbn");
20+
assert(r == ar.sum!"kb2");
21+
assert(r == ar.sum!"precise");
2222
}
2323

2424
///
@@ -33,8 +33,8 @@ unittest
3333
.map!(n => 1.7L.pow(n+1) - 1.7L.pow(n))
3434
;
3535
real d = 1.7L.pow(1000);
36-
assert(sum!(Summation.precise)(concatenation(ar, [-d].sliced).slicedField) == -1);
37-
assert(sum!(Summation.precise)(ar.retro, -d) == -1);
36+
assert(sum!"precise"(concatenation(ar, [-d].sliced).slicedField) == -1);
37+
assert(sum!"precise"(ar.retro, -d) == -1);
3838
}
3939

4040
/++
@@ -84,9 +84,9 @@ unittest
8484
p.rijk = [3, 4, 5, 9];
8585
r.rijk = [3, 5, 7, 13];
8686

87-
assert(r == [p, q].sum!(Summation.naive));
88-
assert(r == [p, q].sum!(Summation.pairwise));
89-
assert(r == [p, q].sum!(Summation.kahan));
87+
assert(r == [p, q].sum!"naive");
88+
assert(r == [p, q].sum!"pairwise");
89+
assert(r == [p, q].sum!"kahan");
9090
}
9191

9292
/++
@@ -96,16 +96,16 @@ unittest
9696
{
9797
cdouble[] ar = [1.0 + 2i, 2 + 3i, 3 + 4i, 4 + 5i];
9898
cdouble r = 10 + 14i;
99-
assert(r == ar.sum!(Summation.fast));
100-
assert(r == ar.sum!(Summation.naive));
101-
assert(r == ar.sum!(Summation.pairwise));
102-
assert(r == ar.sum!(Summation.kahan));
99+
assert(r == ar.sum!"fast");
100+
assert(r == ar.sum!"naive");
101+
assert(r == ar.sum!"pairwise");
102+
assert(r == ar.sum!"kahan");
103103
version(LDC) // DMD Internal error: backend/cgxmm.c 628
104104
{
105-
assert(r == ar.sum!(Summation.kbn));
106-
assert(r == ar.sum!(Summation.kb2));
105+
assert(r == ar.sum!"kbn");
106+
assert(r == ar.sum!"kb2");
107107
}
108-
assert(r == ar.sum!(Summation.precise));
108+
assert(r == ar.sum!"precise");
109109
}
110110

111111
///
@@ -149,7 +149,7 @@ nothrow @nogc unittest
149149
import mir.ndslice.topology: iota, map;
150150
import core.stdc.tgmath: pow;
151151
assert(iota(1000).map!(n => 1.7L.pow(real(n)+1) - 1.7L.pow(real(n)))
152-
.sum!(Summation.precise) == -1 + 1.7L.pow(1000.0L));
152+
.sum!"precise" == -1 + 1.7L.pow(1000.0L));
153153
}
154154

155155
/// Precise summation with output range
@@ -909,78 +909,87 @@ public:
909909
void put(Range)(Range r)
910910
if (isIterable!Range)
911911
{
912-
static if (summation == Summation.pairwise)
912+
static if (summation == Summation.pairwise && fastPairwise && isDynamicArray!Range)
913913
{
914-
import mir.ndslice.slice: isSlice;
915-
static if (fastPairwise && isDynamicArray!Range && isSlice!Range)
914+
F[registersCount] v = void;
915+
foreach (i, n; chainSeq!registersCount)
916916
{
917-
F[registersCount] v = void;
918-
foreach (i, n; chainSeq!registersCount)
917+
if (r.length >= n * 2) do
919918
{
920-
if (r.length >= n * 2) do
921-
{
922-
foreach (j; Iota!n)
923-
v[j] = cast(F) r[j];
924-
foreach (j; Iota!n)
925-
v[j] += cast(F) r[n + j];
926-
foreach (m; chainSeq!(n / 2))
927-
foreach (j; Iota!m)
928-
v[j] += v[m + j];
929-
put(v[0]);
930-
r.popFrontExactly(n * 2);
931-
}
932-
while (!i && r.length >= n * 2);
919+
foreach (j; Iota!n)
920+
v[j] = cast(F) r[j];
921+
foreach (j; Iota!n)
922+
v[j] += cast(F) r[n + j];
923+
foreach (m; chainSeq!(n / 2))
924+
foreach (j; Iota!m)
925+
v[j] += v[m + j];
926+
put(v[0]);
927+
r = r[n * 2 .. $];
933928
}
934-
if (r.length)
935-
{
936-
put(cast(F) r[0]);
937-
r.popFront;
938-
}
939-
assert(r.empty);
929+
while (!i && r.length >= n * 2);
940930
}
941-
else
931+
if (r.length)
942932
{
943-
foreach (elem; r)
944-
put(elem);
933+
put(cast(F) r[0]);
934+
r = r[1 .. $];
945935
}
936+
assert(r.length == 0);
946937
}
947938
else
948-
static if (summation == Summation.precise)
939+
static if (summation == Summation.fast)
949940
{
950-
foreach (elem; r)
951-
put(elem);
941+
static if (isComplex!T)
942+
F s0 = 0 + 0fi;
943+
else
944+
F s0 = 0;
945+
foreach (ref elem; r)
946+
s0 += elem;
947+
s += s0;
952948
}
953949
else
954-
static if (summation == Summation.kb2)
955950
{
956-
foreach (elem; r)
951+
foreach (ref elem; r)
957952
put(elem);
958953
}
959-
else
960-
static if (summation == Summation.kbn)
954+
}
955+
956+
import mir.ndslice.slice;
957+
958+
/// ditto
959+
void put(Range: Slice!(kind, packs, Iterator), SliceKind kind, size_t[] packs, Iterator)(Range r)
960+
{
961+
static if (packs.length > 1)
961962
{
962-
foreach (elem; r)
963-
put(elem);
963+
import mir.ndslice.topology: unpack;
964+
this.put(r.unpack);
964965
}
965966
else
966-
static if (summation == Summation.kahan)
967+
static if (packs[0] > 1 && kind == Contiguous)
967968
{
968-
foreach (elem; r)
969-
put(elem);
969+
import mir.ndslice.topology: flattened;
970+
this.put(r.flattened);
970971
}
971-
else static if (summation == Summation.naive)
972+
else
973+
static if (isPointer!Iterator && kind == Contiguous)
972974
{
973-
foreach (elem; r)
974-
s += elem;
975+
this.put(r.iterator[0 .. r.length]);
975976
}
976977
else
977-
static if (summation == Summation.fast)
978+
static if (summation == Summation.fast && packs[0] == 1)
978979
{
979-
foreach (elem; r)
980-
s += elem;
980+
static if (isComplex!T)
981+
F s0 = 0 + 0fi;
982+
else
983+
F s0 = 0;
984+
import mir.ndslice.algorithm: reduce;
985+
s0 = s0.reduce!"a + b"(r);
986+
s += s0;
981987
}
982988
else
983-
static assert(0);
989+
{
990+
foreach(elem; r)
991+
this.put(elem);
992+
}
984993
}
985994

986995
/+
@@ -1694,6 +1703,19 @@ template sum(Summation summation = Summation.appropriate)
16941703
}
16951704
}
16961705

1706+
///ditto
1707+
template sum(F, string summation)
1708+
if (isFloatingPoint!F && isMutable!F)
1709+
{
1710+
mixin("alias sum = .sum!(F, Summation." ~ summation ~ ");");
1711+
}
1712+
1713+
///ditto
1714+
template sum(string summation)
1715+
{
1716+
mixin("alias sum = .sum!(Summation." ~ summation ~ ");");
1717+
}
1718+
16971719

16981720
@safe pure nothrow unittest
16991721
{
@@ -1837,22 +1859,31 @@ private F sumPrecise(Range, F)(Range r, F seed = summationInitValue!F)
18371859
static if (isFloatingPoint!F)
18381860
{
18391861
auto sum = Summator!(F, Summation.precise)(seed);
1840-
for (; !r.empty; r.popFront)
1841-
{
1842-
sum.put(r.front);
1843-
}
1862+
sum.put(r);
18441863
return sum.sum;
18451864
}
18461865
else
18471866
{
18481867
alias T = typeof(F.init.re);
18491868
auto sumRe = Summator!(T, Summation.precise)(seed.re);
18501869
auto sumIm = Summator!(T, Summation.precise)(seed.im);
1851-
foreach (elem; r)
1870+
import mir.ndslice.slice: isSlice;
1871+
static if (isSlice!Range)
18521872
{
1853-
auto e = elem;
1854-
sumRe.put(e.re);
1855-
sumIm.put(e.im);
1873+
import mir.ndslice.algorithm: each;
1874+
r.each!((auto ref elem)
1875+
{
1876+
sumRe.put(elem.re);
1877+
sumIm.put(elem.im);
1878+
});
1879+
}
1880+
else
1881+
{
1882+
foreach (ref elem; r)
1883+
{
1884+
sumRe.put(elem.re);
1885+
sumIm.put(elem.im);
1886+
}
18561887
}
18571888
return sumRe.sum + sumIm.sum * 1fi;
18581889
}
@@ -1921,7 +1952,11 @@ private T summationInitValue(T)()
19211952

19221953
private template sumType(Range)
19231954
{
1924-
alias T = Unqual!(ForeachType!Range);
1955+
import mir.ndslice.slice: isSlice, DeepElementType;
1956+
static if (isSlice!Range)
1957+
alias T = Unqual!(DeepElementType!(Range.PureThis));
1958+
else
1959+
alias T = Unqual!(ForeachType!Range);
19251960
alias sumType = typeof(T.init + T.init);
19261961
}
19271962

@@ -1945,7 +1980,7 @@ template isSummable(Range, F)
19451980
{
19461981
enum bool isSummable =
19471982
isIterable!Range &&
1948-
isImplicitlyConvertible!(Unqual!(ForeachType!Range), F) &&
1983+
isImplicitlyConvertible!(sumType!Range, F) &&
19491984
isSummable!F;
19501985
}
19511986

@@ -1960,3 +1995,52 @@ private enum bool isCompesatorAlgorithm(Summation summation) =
19601995
|| summation == Summation.kb2
19611996
|| summation == Summation.kbn
19621997
|| summation == Summation.kahan;
1998+
1999+
2000+
unittest
2001+
{
2002+
import mir.ndslice;
2003+
2004+
auto p = iota([2, 3, 4, 5]);
2005+
auto a = p.as!double;
2006+
auto b = a.flattened;
2007+
auto c = a.slice;
2008+
auto d = c.flattened;
2009+
auto s = p.flattened.sum;
2010+
2011+
assert(a.sum == s);
2012+
assert(b.sum == s);
2013+
assert(c.sum == s);
2014+
assert(d.sum == s);
2015+
2016+
assert(a.canonical.sum == s);
2017+
assert(b.canonical.sum == s);
2018+
assert(c.canonical.sum == s);
2019+
assert(d.canonical.sum == s);
2020+
2021+
assert(a.universal.transposed!3.sum == s);
2022+
assert(b.universal.sum == s);
2023+
assert(c.universal.transposed!3.sum == s);
2024+
assert(d.universal.sum == s);
2025+
2026+
assert(a.pack!2.sum!"fast" == s);
2027+
assert(c.pack!2.sum!"fast" == s);
2028+
2029+
assert(a.sum!"fast" == s);
2030+
assert(b.sum!"fast" == s);
2031+
assert(c.sum!(float, "fast") == s);
2032+
assert(d.sum!"fast" == s);
2033+
2034+
assert(a.canonical.sum!"fast" == s);
2035+
assert(b.canonical.sum!"fast" == s);
2036+
assert(c.canonical.sum!"fast" == s);
2037+
assert(d.canonical.sum!"fast" == s);
2038+
2039+
assert(a.universal.transposed!3.sum!"fast" == s);
2040+
assert(b.universal.sum!"fast" == s);
2041+
assert(c.universal.transposed!3.sum!"fast" == s);
2042+
assert(d.universal.sum!"fast" == s);
2043+
2044+
assert(a.pack!2.sum!"fast" == s);
2045+
assert(c.pack!2.sum!"fast" == s);
2046+
}

source/mir/ndslice/slice.d

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ struct Slice(SliceKind kind, size_t[] packs, Iterator)
521521
{
522522
@fastmath:
523523

524-
package:
524+
package(mir):
525525

526526
///
527527
enum N = packs.sum;

0 commit comments

Comments
 (0)