From ff5c1a63735b6fde803dd597599213c5850eb8b6 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 28 Apr 2021 14:00:06 +0200 Subject: [PATCH] Implement p-norm. --- Project.toml | 6 ------ src/host/linalg.jl | 13 +++++++++++++ test/Project.toml | 6 ++++++ test/testsuite/linalg.jl | 9 ++++++++- 4 files changed, 27 insertions(+), 7 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index 0c884664a..453500d89 100644 --- a/Project.toml +++ b/Project.toml @@ -15,9 +15,3 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" AbstractFFTs = "0.4, 0.5, 1.0" Adapt = "2.0, 3.0" julia = "1.5" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] diff --git a/src/host/linalg.jl b/src/host/linalg.jl index ecc5a65fd..97f7c9e5b 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -194,3 +194,16 @@ end # TODO: implementation without the memory copy LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) = permutedims!(dest, src, Tuple(perm)) + + +## norm + +function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T} + if p == Inf + maximum(abs.(v)) + elseif p == -Inf + minimum(abs.(v)) + else + mapreduce(x->abs(x)^p, +, v; init=zero(T))^(1/p) + end +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..3bdf0b240 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,6 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 404f1d73b..0879da38a 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -1,5 +1,5 @@ @testsuite "linalg" AT->begin - @testset "adjoint and transpose" begin + @testset "adjoint and trspose" begin @test compare(adjoint, AT, rand(Float32, 32, 32)) @test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32)) @test compare(transpose, AT, rand(Float32, 32, 32)) @@ -121,4 +121,11 @@ end @test compare(rmul!, AT, rand(T, a), Ref(rand(T))) @test compare(lmul!, AT, Ref(rand(T)), rand(T, b)) end + + @testset "$p-norm($sz x $T)" for sz in [(2,), (2,2), (2,2,2)], + p in Any[1, 2, 3, Inf, -Inf], + T in supported_eltypes() + range = T <: Integer ? (T(1):T(10)) : T # prevent integer overflow + @test compare(norm, AT, rand(range, sz), Ref(p)) + end end