Skip to content

Commit b804a3b

Browse files
Merge pull request #37 from c42f/cjf/fix-distributed-deserialize
Ensure RGF body is added to the cache when deserializing
2 parents a0872e5 + e5e1403 commit b804a3b

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

src/RuntimeGeneratedFunctions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id} <: Functio
6262
cached_body = _cache_body(cache_tag, id, body)
6363
new{Tuple(args), cache_tag, context_tag, id}(cached_body)
6464
end
65+
66+
# For internal use in deserialize() - doesen't check whether the body is in the cache!
67+
function RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}(body::Expr) where {argnames,cache_tag,context_tag,id}
68+
new{argnames, cache_tag, context_tag, id}(body)
69+
end
6570
end
6671

6772
function _check_rgf_initialized(mods...)
@@ -273,6 +278,15 @@ function closures_to_opaque(ex::Expr, return_type=nothing)
273278
return Expr(head, Any[closures_to_opaque(x, return_type) for x in args]...)
274279
end
275280

281+
# We write an explicit deserialize() here to trigger caching of the body on a
282+
# remote node when using Serialialization.jl (in Distributed.jl and elsewhere)
283+
function Serialization.deserialize(s::AbstractSerializer,
284+
::Type{RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}}) where {argnames,cache_tag,context_tag,id}
285+
body = deserialize(s)
286+
cached_body = _cache_body(cache_tag, id, body)
287+
RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}(cached_body)
288+
end
289+
276290
@specialize
277291

278292
end

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using RuntimeGeneratedFunctions, BenchmarkTools
2+
using Serialization
23
using Test
34

45
RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -158,3 +159,9 @@ if VERSION >= v"1.7.0-DEV.351"
158159
ex = :(x -> [2i for i in 1:x])
159160
@test @RuntimeGeneratedFunction(ex)(3) == [2, 4, 6]
160161
end
162+
163+
# Serialization
164+
165+
buf = IOBuffer(read(`$(Base.julia_cmd()) "serialize_rgf.jl"`))
166+
deserialized_f = deserialize(buf)
167+
@test deserialized_f(11) == "Hi from a separate process. x=11"

test/serialize_rgf.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Must be run in a separate process from the rest of the tests!
2+
3+
using RuntimeGeneratedFunctions
4+
using Serialization
5+
6+
RuntimeGeneratedFunctions.init(@__MODULE__)
7+
8+
f = @RuntimeGeneratedFunction(:(x->"Hi from a separate process. x=$x"))
9+
10+
serialize(stdout, f)

0 commit comments

Comments
 (0)