Skip to content

Commit e745e3f

Browse files
committed
feat: add functional version of nif for easier @generated
1 parent 3679d50 commit e745e3f

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed

src/Utils.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,55 @@ struct ResultOk2{A<:AbstractArray,B<:AbstractArray}
153153
ok::Bool
154154
end
155155

156+
#! format: on
157+
"""
158+
nif(condition, expression, [else_expression,] ::Val{N}) where {N}
159+
160+
Generate a sequence of `if ... elseif ... else ... end` statements.
161+
This is copied from https://github.com/JuliaLang/julia/pull/55093.
162+
163+
# Arguments
164+
- `condition`: A function that takes an integer between `1` and `N-1` and
165+
returns a boolean condition.
166+
- `expression`: A function that takes an integer between `1` and `N` (or,
167+
only up to `N-1`, if `else_expression` is provided) and is called if
168+
the condition is true.
169+
- `else_expression`: (optional) A function that takes `N` as input
170+
returns an expression to be evaluated if all conditions are false.
171+
- `N`: The number of conditions to check, passed as a `Val{N}` instance.
172+
This function is similar to the `@nif` macro but can be used in cases
173+
where `N` is not known at parse time.
174+
# Examples
175+
For example, here we find the first index of a positive element in a
176+
fixed-size tuple using `nif`:
177+
```jldoctest
178+
julia> x = (0, -1, 1, 0)
179+
(0, -1, 1, 0)
180+
julia> Base.Cartesian.nif(d -> x[d] > 0, d -> d, Val(4))
181+
3
182+
```
183+
"""
184+
@inline function nif(condition::F, expression::G, ::Val{N}) where {F,G,N}
185+
return nif(condition, expression, expression, Val(N))
186+
end
187+
@inline function nif(
188+
condition::F, expression::G, else_expression::H, ::Val{N}
189+
) where {F,G,H,N}
190+
n = N::Int # Can improve inference; see #54544
191+
(n >= 0) || throw(ArgumentError("if statement length should be ≥ 0, got $n"))
192+
if @generated
193+
:(Base.Cartesian.@nif $N d -> condition(d) d -> expression(d) d ->
194+
else_expression(d))
195+
else
196+
for d in 1:(n - 1)
197+
if condition(d)
198+
return expression(d)
199+
end
200+
end
201+
return else_expression(n)
202+
end
203+
end
204+
typeof(function nif end).name.max_methods = UInt8(2)
205+
#! format: off
206+
156207
end

test/test_nif.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copied from https://github.com/JuliaLang/julia/pull/55093
2+
@testitem "nif" begin
3+
using DynamicExpressions.UtilsModule: nif
4+
5+
x = (0, -1, 1, 0)
6+
@test nif(d -> x[d] > 0, d -> d, Val(4)) == 3
7+
8+
@test nif(d -> d > 1, d -> "A", d -> "B", Val(1)) == "B"
9+
@test nif(d -> d > 3, d -> "A", d -> "B", Val(3)) == "B"
10+
11+
# Test with N = 0
12+
@test nif(d -> d > 0, d -> "", d -> "A", Val(0)) == "A"
13+
14+
# Specific branch true
15+
@test nif(d -> d == 2, d -> d, d -> "else", Val(3)) == 2
16+
17+
# Test with condition only true for last branch
18+
@test nif(d -> d == 5, d -> "A", d -> "B", Val(5)) == "B"
19+
20+
# Test with bad input:
21+
@test_throws ArgumentError("if statement length should be ≥ 0, got -1") nif(
22+
identity, identity, Val(-1)
23+
)
24+
25+
# Non-Int64 also throws
26+
@test_throws TypeError nif(identity, identity, Val(1.5))
27+
28+
# Make sure all conditions are actually evaluated
29+
result = let c = Ref(0)
30+
nif(d -> (c[] += 1; false), d -> 1, Val(4))
31+
c[]
32+
end
33+
@test result == 3
34+
35+
# Test inference is good
36+
t = ("i am not an int", ntuple(d -> d, Val(10))...)
37+
function extract_from_tuple(t::Tuple, i)
38+
return nif(
39+
d -> d == i,
40+
d -> t[d + 1], # We skip the non-integer element
41+
Val(length(t) - 1),
42+
)
43+
end
44+
# Normally, had we used getindex here, inference would have
45+
# not been able to infer that the return type never includes
46+
# the first element. But since we used an `nif`, the compiler
47+
# knows all possible branches and can infer the correct type.
48+
@test @inferred(extract_from_tuple(t, 3)) == 3
49+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,4 @@ include("test_operator_construction_edgecases.jl")
129129
include("test_node_interface.jl")
130130
include("test_expression_math.jl")
131131
include("test_structured_expression.jl")
132+
include("test_nif.jl")

0 commit comments

Comments
 (0)