diff --git a/benchmarks/callperf.jl b/benchmarks/callperf.jl new file mode 100644 index 00000000..1f0f4b00 --- /dev/null +++ b/benchmarks/callperf.jl @@ -0,0 +1,35 @@ +using PyCall, BenchmarkTools, DataStructures + +results = OrderedDict{String,Any}() + +let + np = pyimport("numpy") + nprand = np["random"]["rand"] + nprand_pyo(sz...) = pycall(nprand, PyObject, sz...) + ret = PyNULL() + args_lens = (0,3,7,12,17) + arr_sizes = (ntuple(i->1, len) for len in args_lens) + nprand_wraps = [PyFuncWrap(nprand, map(typeof, arr_size)) for arr_size in arr_sizes] + @show typeof(nprand_wraps) + for (i, arr_size) in enumerate(arr_sizes) + nprand_wrap = nprand_wraps[i] + arr_size_str = args_lens[i] < 5 ? "$arr_size" : "$(args_lens[i])*(1,1,...)" + results["nprand_pyo $arr_size_str"] = @benchmark $nprand_pyo($arr_size...) + println("nprand_pyo $arr_size_str:\n"); display(results["nprand_pyo $arr_size_str"]) + println("--------------------------------------------------") + + results["nprand_wrap $arr_size_str"] = @benchmark $nprand_wrap($arr_size...) + println("nprand_wrap $arr_size_str:\n"); display(results["nprand_wrap $arr_size_str"]) + println("--------------------------------------------------") + + # args already set by nprand_wrap calls above + results["nprand_wrap_noargs $arr_size_str"] = @benchmark $nprand_wrap() + println("nprand_wrap_noargs $arr_size_str:\n"); display(results["nprand_wrap_noargs $arr_size_str"]) + println("--------------------------------------------------") + end +end + +println("") +println("Mean times") +println("----------") +foreach((r)->println(rpad(r[1],33), ": ", mean(r[2])), results) diff --git a/src/PyCall.jl b/src/PyCall.jl index 52a163bd..687bd227 100644 --- a/src/PyCall.jl +++ b/src/PyCall.jl @@ -10,7 +10,8 @@ export pycall, pyimport, pybuiltin, PyObject, PyReverseDims, pyisinstance, pywrap, pytypeof, pyeval, PyVector, pystring, pystr, pyrepr, pyraise, pytype_mapping, pygui, pygui_start, pygui_stop, pygui_stop_all, @pylab, set!, PyTextIO, @pysym, PyNULL, @pydef, - pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret + pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret, + PyFuncWrap, setarg!, setargs! import Base: size, ndims, similar, copy, getindex, setindex!, stride, convert, pointer, summary, convert, show, haskey, keys, values, @@ -170,6 +171,7 @@ include("pytype.jl") include("pyiterator.jl") include("pyclass.jl") include("callback.jl") +include("pyfuncwrap.jl") include("io.jl") ######################################################################### diff --git a/src/conversions.jl b/src/conversions.jl index 014b8f26..5d5b8bbf 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -174,6 +174,11 @@ end # somewhat annoying to get the length and types in a tuple type # ... would be better not to have to use undocumented internals! +function tuplen(T::DataType) + isvatuple(T) && ArgumentError("can't determine length of vararg tuple: $T") + return length(T.parameters) +end +tuplen(T::UnionAll) = tuplen(T.body) istuplen(T,isva,n) = isva ? n ≥ length(T.parameters)-1 : n == length(T.parameters) function tuptype(T::DataType,isva,i) if isva && i ≥ length(T.parameters) diff --git a/src/pyfuncwrap.jl b/src/pyfuncwrap.jl new file mode 100644 index 00000000..6d118cbc --- /dev/null +++ b/src/pyfuncwrap.jl @@ -0,0 +1,91 @@ +struct PyFuncWrap{P<:Union{PyObject,PyPtr}, AT<:Tuple, N, RT} + o::P + oargs::Vector{PyObject} + pyargsptr::PyPtr + ret::PyObject +end + +""" +``` +PyFuncWrap(o::P, argtypes::Tuple #= of Types =#, returntype::Type) +``` + +Wrap a callable PyObject/PyPtr to reduce the number of allocations made for +passing its arguments, and its return value, sometimes providing a speedup. +Mainly useful for functions called in a tight loop, particularly if most or +all of the arguments to the function don't change. +``` +@pyimport numpy as np +rand22fn = PyFuncWrap(np.random["rand"], (Int, Int), PyArray) +setargs!(rand22fn, 2, 2) +for i in 1:10^9 + arr = rand22fn() + ... +end +``` +""" +function PyFuncWrap(o::P, argtypes::Tuple{Vararg{<:Union{Tuple, Type}}}, + returntype::Type{RT}=PyObject) where {P<:Union{PyObject,PyPtr}, RT} + AT = typeof(argtypes) + isvatuple(AT) && throw(ArgumentError("Vararg functions not supported, arg signature provided: $AT")) + N = tuplen(AT) + oargs = Array{PyObject}(N) + pyargsptr = ccall((@pysym :PyTuple_New), PyPtr, (Int,), N) + return PyFuncWrap{P, AT, N, RT}(o, oargs, pyargsptr, PyNULL()) +end + +""" +``` +setargs!(pf::PyFuncWrap, args...) +``` +Set the arguments to a python function wrapped in a PyFuncWrap, and convert them +to `PyObject`s that can be passed directly to python when the function is +called. After the arguments have been set, the function can be efficiently +called with `pf()` +""" +function setargs!(pf::PyFuncWrap{P, AT, N, RT}, args...) where {P, AT, RT, N} + for i = 1:N + setarg!(pf, args[i], i) + end + nothing +end + +""" +``` +setarg!(pf::PyFuncWrap, arg, i::Integer=1) +``` +Set the `i`th argument to a python function wrapped in a PyFuncWrap, and convert +it to a `PyObject` that can be passed directly to python when the function is +called. Useful if a function takes multiple arguments, but only one or two of +them change, when calling the function in a tight loop +""" +function setarg!(pf::PyFuncWrap{P, AT, N, RT}, arg, i::Integer=1) where {P, AT, N, RT} + pf.oargs[i] = PyObject(arg) + @pycheckz ccall((@pysym :PyTuple_SetItem), Cint, + (PyPtr,Int,PyPtr), pf.pyargsptr, i-1, pf.oargs[i]) + pyincref(pf.oargs[i]) # PyTuple_SetItem steals the reference + nothing +end + +function (pf::PyFuncWrap{P, AT, N, RT})(args...) where {P, AT, N, RT} + setargs!(pf, args...) + return pf() +end + +""" +Warning: if pf(args) or setargs(pf, ...) hasn't been called yet, this will likely segfault +""" +function (pf::PyFuncWrap{P, AT, N, RT})() where {P, AT, N, RT} + sigatomic_begin() + try + kw = C_NULL + retptr = ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o, + pf.pyargsptr, kw) + pyincref_(retptr) + pydecref(pf.ret) + pf.ret.o = retptr + finally + sigatomic_end() + end + convert(RT, pf.ret) +end diff --git a/test/runtests.jl b/test/runtests.jl index b7dd2e8f..8ecbe0c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -538,3 +538,5 @@ end @test pyfunctionret(factorial, nothing, Int)(3) === nothing @test PyCall.is_pyjlwrap(pycall(pyfunctionret(factorial, Any, Int), PyObject, 3)) end + +include("test_pyfuncwrap.jl") \ No newline at end of file diff --git a/test/test_pyfuncwrap.jl b/test/test_pyfuncwrap.jl new file mode 100644 index 00000000..ed213d54 --- /dev/null +++ b/test/test_pyfuncwrap.jl @@ -0,0 +1,29 @@ +using Compat.Test, PyCall + +@testset "PyFuncWrap" begin + np = pyimport("numpy") + ops = pyimport("operator") + eq = ops["eq"] + npzeros = np["zeros"] + npzeros_pyo(sz, dtype="d", order="F") = pycall(npzeros, PyObject, sz, dtype, order) + npzeros_pyany(sz, dtype="d", order="F") = pycall(npzeros, PyAny, sz, dtype, order) + npzeros_pyarray(sz, dtype="d", order="F") = pycall(npzeros, PyArray, sz, dtype, order) + + # PyObject is default returntype + npzeros2dwrap_pyo = PyFuncWrap(npzeros, ((Int, Int), String, String)) + npzeros2dwrap_pyany = PyFuncWrap(npzeros, ((Int, Int), String, String), PyAny) + npzeros2dwrap_pyarray = PyFuncWrap(npzeros, ((Int, Int), String, String), PyArray) + + arr_size = (2,2) + + # all args + @test np["array_equal"](npzeros2dwrap_pyo(arr_size, "d", "F"), npzeros_pyo(arr_size)) + # args already set + @test np["array_equal"](npzeros2dwrap_pyo(), npzeros_pyo(arr_size)) + + @test all(npzeros2dwrap_pyany(arr_size, "d", "F") .== npzeros_pyany(arr_size)) + @test all(npzeros2dwrap_pyany() .== npzeros_pyany(arr_size)) + + @test all(npzeros2dwrap_pyarray(arr_size, "d", "F") .== npzeros_pyarray(arr_size)) + @test all(npzeros2dwrap_pyarray() .== npzeros_pyarray(arr_size)) +end