@@ -4,16 +4,6 @@ using ..UtilsModule: deprecate_varmap
44using .. OperatorEnumModule: AbstractOperatorEnum
55using .. NodeModule: AbstractExpressionNode, tree_mapreduce
66
7- const OP_NAMES = Base. ImmutableDict (
8- " safe_log" => " log" ,
9- " safe_log2" => " log2" ,
10- " safe_log10" => " log10" ,
11- " safe_log1p" => " log1p" ,
12- " safe_acosh" => " acosh" ,
13- " safe_sqrt" => " sqrt" ,
14- " safe_pow" => " ^" ,
15- )
16-
177function dispatch_op_name (:: Val{deg} , :: Nothing , idx):: Vector{Char} where {deg}
188 if deg == 1
199 return vcat (collect (" unary_operator[" ), collect (string (idx)), [' ]' ])
@@ -23,34 +13,38 @@ function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
2313end
2414function dispatch_op_name (:: Val{deg} , operators:: AbstractOperatorEnum , idx) where {deg}
2515 if deg == 1
26- return get_op_name (operators. unaops[idx]):: Vector{Char}
16+ return collect ( get_op_name (operators. unaops[idx]):: String )
2717 else
28- return get_op_name (operators. binops[idx]):: Vector{Char}
18+ return collect ( get_op_name (operators. binops[idx]):: String )
2919 end
3020end
3121
32- @generated function get_op_name (op:: F ):: Vector{Char} where {F}
22+ const OP_NAME_CACHE = (; x= Dict {UInt64,String} (), lock= Threads. SpinLock ())
23+
24+ function get_op_name (op:: F ) where {F}
25+ h = hash (op)
26+ lock (OP_NAME_CACHE. lock)
3327 try
34- # Bit faster to just cache the name of the operator:
35- op_s = if F <: Broadcast.BroadcastFunction
36- string (F. parameters[1 ]. instance) * ' .'
37- else
38- string (F. instance)
28+ cache = OP_NAME_CACHE. x
29+ if haskey (cache, h)
30+ return cache[h]
3931 end
40- if length ( op_s) == 2 && op_s[ 1 ] in ( ' + ' , ' - ' , ' * ' , ' / ' , ' ^ ' ) && op_s[ 2 ] == ' . '
41- op_s = ' . ' * op_s[ 1 ]
42- end
43- out = collect ( get (OP_NAMES, op_s, op_s))
44- return :( $ out )
45- catch
46- end
47- return quote
48- op_s = typeof (op) <: Broadcast.BroadcastFunction ? string (op . f) * ' . ' : string (op)
49- if length (op_s) == 2 && op_s[ 1 ] in ( ' + ' , ' - ' , ' * ' , ' / ' , ' ^ ' ) && op_s[ 2 ] == ' . '
50- op_s = ' . ' * op_s[ 1 ]
32+ op_s = if op isa Broadcast . BroadcastFunction
33+ base_op_s = string (op . f)
34+ if length (base_op_s) == 1 && first (base_op_s) in ( ' + ' , ' - ' , ' * ' , ' / ' , ' ^ ' )
35+ # Like `.+`
36+ string ( ' . ' , base_op_s )
37+ else
38+ # Like `cos.`
39+ string (base_op_s, ' . ' )
40+ end
41+ else
42+ string (op)
5143 end
52- out = collect (get (OP_NAMES, op_s, op_s))
53- return out
44+ cache[h] = op_s
45+ return op_s
46+ finally
47+ unlock (OP_NAME_CACHE. lock)
5448 end
5549end
5650
0 commit comments