Skip to content

Commit 2c3fbed

Browse files
committed
Separate RGF code cache from context for variable lookup
The module for RGF code cache may be different from the context for lookup of global variables used in the function's AST. This change allows these to be separate. Also fix a problem with `show()` mime type and REPL usage — use at the REPL was finding the definition for `show(::IO, ::MIME"text/plain", ::Function)` rather than ours.
1 parent 4e5de90 commit 2c3fbed

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

src/RuntimeGeneratedFunctions.jl

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ export @RuntimeGeneratedFunction
1010
1111
This type should be constructed via the macro @RuntimeGeneratedFunction.
1212
"""
13-
struct RuntimeGeneratedFunction{argnames,moduletag,id} <: Function
13+
struct RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id} <: Function
1414
body::Expr
15-
function RuntimeGeneratedFunction(moduletag, ex)
15+
function RuntimeGeneratedFunction(cache_tag, context_tag, ex)
1616
def = splitdef(ex)
1717
args, body = normalize_args(def[:args]), def[:body]
1818
id = expr_to_id(body)
19-
cached_body = _cache_body(moduletag, id, body)
20-
new{Tuple(args),moduletag,id}(cached_body)
19+
cached_body = _cache_body(cache_tag, id, body)
20+
new{Tuple(args), cache_tag, context_tag, id}(cached_body)
2121
end
2222
end
2323

2424
"""
2525
@RuntimeGeneratedFunction(function_expression)
26+
@RuntimeGeneratedFunction(context_module, function_expression)
2627
2728
Construct a function from `function_expression` which can be called immediately
2829
without world age problems. Somewhat like using `eval(function_expression)` and
@@ -35,6 +36,10 @@ then calling the resulting function. The differences are:
3536
You need to use `RuntimeGeneratedFunctions.init(your_module)` a single time at
3637
the top level of `your_module` before any other uses of the macro.
3738
39+
If provided, `context_module` is module in which symbols within
40+
`function_expression` will be looked up. By default this is module in which
41+
`@RuntimeGeneratedFunction` is expanded.
42+
3843
# Examples
3944
```
4045
RuntimeGeneratedFunctions.init(@__MODULE__) # Required at module top-level
@@ -46,23 +51,33 @@ function foo()
4651
end
4752
```
4853
"""
49-
macro RuntimeGeneratedFunction(ex)
54+
macro RuntimeGeneratedFunction(code)
55+
_RGF_constructor_code(:(@__MODULE__), esc(code))
56+
end
57+
macro RuntimeGeneratedFunction(context_module, code)
58+
_RGF_constructor_code(esc(context_module), esc(code))
59+
end
60+
61+
function _RGF_constructor_code(context_module, code)
5062
quote
51-
if !($(esc(:(@isdefined($_tagname)))))
63+
code = $code
64+
cache_module = @__MODULE__
65+
context_module = $context_module
66+
if #==# !isdefined(cache_module, $(QuoteNode(_tagname))) ||
67+
!isdefined(context_module, $(QuoteNode(_tagname)))
68+
init_mods = unique([context_module, cache_module])
5269
error("""You must use `RuntimeGeneratedFunctions.init(@__MODULE__)` at module
53-
top level before using runtime generated functions""")
70+
top level before using runtime generated functions in $init_mods""")
5471
end
55-
RuntimeGeneratedFunction(
56-
$(esc(_tagname)),
57-
$(esc(ex))
58-
)
72+
RuntimeGeneratedFunction(cache_module.$_tagname, context_module.$_tagname, $code)
5973
end
6074
end
6175

62-
function Base.show(io::IO, f::RuntimeGeneratedFunction{argnames, moduletag, id}) where {argnames,moduletag,id}
63-
mod = parentmodule(moduletag)
76+
function Base.show(io::IO, ::MIME"text/plain", f::RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id}) where {argnames,cache_tag,context_tag,id}
77+
cache_mod = parentmodule(cache_tag)
78+
context_mod = parentmodule(context_tag)
6479
func_expr = Expr(:->, Expr(:tuple, argnames...), f.body)
65-
print(io, "RuntimeGeneratedFunction(#=in $mod=#, ", repr(func_expr), ")")
80+
print(io, "RuntimeGeneratedFunction(#=in $cache_mod=#, #=using $context_mod=#, ", repr(func_expr), ")")
6681
end
6782

6883
(f::RuntimeGeneratedFunction)(args::Vararg{Any,N}) where N = generated_callfunc(f, args...)
@@ -71,9 +86,9 @@ end
7186
# @RuntimeGeneratedFunction
7287
function generated_callfunc end
7388

74-
function generated_callfunc_body(argnames, moduletag, id, __args)
89+
function generated_callfunc_body(argnames, cache_tag, id, __args)
7590
setup = (:($(argnames[i]) = @inbounds __args[$i]) for i in 1:length(argnames))
76-
body = _lookup_body(moduletag, id)
91+
body = _lookup_body(cache_tag, id)
7792
@assert body !== nothing
7893
quote
7994
$(setup...)
@@ -103,9 +118,9 @@ _cache_lock = Threads.SpinLock()
103118
_cachename = Symbol("#_RuntimeGeneratedFunctions_cache")
104119
_tagname = Symbol("#_RGF_ModTag")
105120

106-
function _cache_body(moduletag, id, body)
121+
function _cache_body(cache_tag, id, body)
107122
lock(_cache_lock) do
108-
cache = getfield(parentmodule(moduletag), _cachename)
123+
cache = getfield(parentmodule(cache_tag), _cachename)
109124
# Caching is tricky when `id` is the same for different AST instances:
110125
#
111126
# Tricky case #1: If a function body with the same `id` was cached
@@ -127,9 +142,9 @@ function _cache_body(moduletag, id, body)
127142
end
128143
end
129144

130-
function _lookup_body(moduletag, id)
145+
function _lookup_body(cache_tag, id)
131146
lock(_cache_lock) do
132-
cache = getfield(parentmodule(moduletag), _cachename)
147+
cache = getfield(parentmodule(cache_tag), _cachename)
133148
cache[id].value
134149
end
135150
end
@@ -159,8 +174,9 @@ function init(mod)
159174
# or so. See:
160175
# https://github.com/JuliaLang/julia/pull/32902
161176
# https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30
162-
@inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, $_tagname, id}, __args...) where {argnames,id}
163-
$RuntimeGeneratedFunctions.generated_callfunc_body(argnames, $_tagname, id, __args)
177+
@inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(
178+
f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, cache_tag, $_tagname, id}, __args...) where {argnames, cache_tag, id}
179+
$RuntimeGeneratedFunctions.generated_callfunc_body(argnames, cache_tag, id, __args)
164180
end
165181
end)
166182
end

test/runtests.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ end
7272
@test no_worldage() === nothing
7373

7474
# Test show()
75-
@test sprint(show, @RuntimeGeneratedFunction(Base.remove_linenums!(:((x,y)->x+y+1)))) ==
75+
@test sprint(show, MIME"text/plain"(), @RuntimeGeneratedFunction(Base.remove_linenums!(:((x,y)->x+y+1)))) ==
7676
"""
77-
RuntimeGeneratedFunction(#=in $(@__MODULE__)=#, :((x, y)->begin
77+
RuntimeGeneratedFunction(#=in $(@__MODULE__)=#, #=using $(@__MODULE__)=#, :((x, y)->begin
7878
x + y + 1
7979
end))"""
8080

@@ -118,12 +118,15 @@ module GlobalsTest
118118
using RuntimeGeneratedFunctions
119119
RuntimeGeneratedFunctions.init(@__MODULE__)
120120

121-
y = 40
122-
f = @RuntimeGeneratedFunction(:(x->x+y))
121+
y_in_GlobalsTest = 40
122+
f = @RuntimeGeneratedFunction(:(x->x + y_in_GlobalsTest))
123123
end
124124

125125
@test GlobalsTest.f(2) == 42
126126

127+
f_outside = @RuntimeGeneratedFunction(GlobalsTest, :(x->x + y_in_GlobalsTest))
128+
@test f_outside(2) == 42
129+
127130
@test_throws ErrorException @eval(module NotInitTest
128131
using RuntimeGeneratedFunctions
129132
# RuntimeGeneratedFunctions.init(@__MODULE__) # <-- missing

0 commit comments

Comments
 (0)