Skip to content

Commit 8a499fb

Browse files
committed
Ensure globals are looked up in the user's module
Create method of `generated_callfunc` in the user's module so that any global symbols within the body will be looked up in the user's module scope. This is straightforward but clunky. A neater solution should be to explicitly expand in the user's module and return a CodeInfo from `generated_callfunc`, but it seems we'd need `jl_expand_and_resolve` which doesn't exist until Julia 1.3 or so. See: * JuliaLang/julia#32902 * https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30
1 parent bbfaf60 commit 8a499fb

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

src/RuntimeGeneratedFunctions.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ then calling the resulting function. The differences are:
3232
* The result is not a named generic function, and doesn't participate in
3333
generic function dispatch; it's more like a callable method.
3434
35+
You need to use the special form `@RuntimeGeneratedFunction __init__` a single time
36+
at the top level of your module before any other uses of the macro.
37+
3538
# Examples
3639
```
40+
@RuntimeGeneratedFunction __init__ # Required once per module
41+
3742
function foo()
3843
expression = :((x,y)->x+y+1) # May be generated dynamically
3944
f = @RuntimeGeneratedFunction(expression)
@@ -42,12 +47,19 @@ end
4247
```
4348
"""
4449
macro RuntimeGeneratedFunction(ex)
45-
_ensure_cache_exists!(__module__)
46-
quote
47-
RuntimeGeneratedFunction(
48-
$(esc(_tagname)),
49-
$(esc(ex))
50-
)
50+
if ex === :__init__
51+
_init_cache!(__module__)
52+
else
53+
quote
54+
if !($(esc(:(@isdefined($_tagname)))))
55+
error("""You must use `@RuntimeGeneratedFunction __init__` at module
56+
top level before using runtime generated functions""")
57+
end
58+
RuntimeGeneratedFunction(
59+
$(esc(_tagname)),
60+
$(esc(ex))
61+
)
62+
end
5163
end
5264
end
5365

@@ -59,7 +71,11 @@ end
5971

6072
(f::RuntimeGeneratedFunction)(args::Vararg{Any,N}) where N = generated_callfunc(f, args...)
6173

62-
@inline @generated function generated_callfunc(f::RuntimeGeneratedFunction{moduletag, id, argnames}, __args...) where {moduletag,id,argnames}
74+
# We'll generate a method of this function in every module which wants to use
75+
# @RuntimeGeneratedFunction
76+
function generated_callfunc end
77+
78+
function generated_callfunc_body(moduletag, id, argnames, __args)
6379
setup = (:($(argnames[i]) = @inbounds __args[$i]) for i in 1:length(argnames))
6480
body = _lookup_body(moduletag, id)
6581
@assert body !== nothing
@@ -122,13 +138,28 @@ function _lookup_body(moduletag, id)
122138
end
123139
end
124140

125-
function _ensure_cache_exists!(mod)
141+
function _init_cache!(mod)
126142
lock(_cache_lock) do
127143
if !isdefined(mod, _cachename)
128144
mod.eval(quote
129145
const $_cachename = Dict()
130146
struct $_tagname
131147
end
148+
149+
# We create method of `generated_callfunc` in the user's module
150+
# so that any global symbols within the body will be looked up
151+
# in the user's module scope.
152+
#
153+
# This is straightforward but clunky. A neater solution should
154+
# be to explicitly expand in the user's module and return a
155+
# CodeInfo from `generated_callfunc`, but it seems we'd need
156+
# `jl_expand_and_resolve` which doesn't exist until Julia 1.3
157+
# or so. See:
158+
# https://github.com/JuliaLang/julia/pull/32902
159+
# https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30
160+
@inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{$_tagname, id, argnames}, __args...) where {id,argnames}
161+
$RuntimeGeneratedFunctions.generated_callfunc_body($_tagname, id, argnames, __args)
162+
end
132163
end)
133164
end
134165
end

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,16 @@ for k=1:4
107107
end
108108
@test all(all.(fetch.(tasks)))
109109

110+
111+
# Test that globals are resolved within the correct scope
112+
113+
module GlobalsTest
114+
115+
using RuntimeGeneratedFunctions
116+
117+
@RuntimeGeneratedFunction __init__
118+
y = 10
119+
f = @RuntimeGeneratedFunction(:(x->x+y))
120+
121+
end
122+

0 commit comments

Comments
 (0)