diff --git a/extensions/pyo3/private/pyo3.bzl b/extensions/pyo3/private/pyo3.bzl index 0e087b55c8..0c936ee7a9 100644 --- a/extensions/pyo3/private/pyo3.bzl +++ b/extensions/pyo3/private/pyo3.bzl @@ -87,10 +87,19 @@ def _py_pyo3_library_impl(ctx): is_windows = extension.basename.endswith(".dll") # https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds - ext = ctx.actions.declare_file("{}{}".format( - ctx.label.name, - ".pyd" if is_windows else ".so", - )) + # Determine the on-disk and logical Python module layout. + module_name = ctx.attr.module if ctx.attr.module else ctx.label.name + + # Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar"). + if ctx.attr.module_prefix: + module_prefix_path = ctx.attr.module_prefix.replace(".", "/") + module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so") + stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name) + else: + module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so") + stub_relpath = "{}.pyi".format(module_name) + + ext = ctx.actions.declare_file(module_relpath) ctx.actions.symlink( output = ext, target_file = extension, @@ -99,10 +108,10 @@ def _py_pyo3_library_impl(ctx): stub = None if _stubs_enabled(ctx.attr.stubs, toolchain): - stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name)) + stub = ctx.actions.declare_file(stub_relpath) args = ctx.actions.args() - args.add(ctx.label.name, format = "--module_name=%s") + args.add(module_name, format = "--module_name=%s") args.add(ext, format = "--module_path=%s") args.add(stub, format = "--output=%s") ctx.actions.run( @@ -180,6 +189,12 @@ py_pyo3_library = rule( "imports": attr.string_list( doc = "List of import directories to be added to the `PYTHONPATH`.", ), + "module": attr.string( + doc = "The Python module name implemented by this extension.", + ), + "module_prefix": attr.string( + doc = "A dotted Python package prefix for the module.", + ), "stubs": attr.int( doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.", default = -1, @@ -218,6 +233,8 @@ def pyo3_extension( stubs = None, version = None, compilation_mode = "opt", + module = None, + module_prefix = None, **kwargs): """Define a PyO3 python extension module. @@ -259,6 +276,8 @@ def pyo3_extension( For more details see [rust_shared_library][rsl]. compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode) value to build the extension for. If set to `"current"`, the current configuration will be used. + module (str, optional): The Python module name implemented by this extension. + module_prefix (str, optional): A dotted Python package prefix for the module. **kwargs (dict): Additional keyword arguments. """ tags = kwargs.pop("tags", []) @@ -318,6 +337,8 @@ def pyo3_extension( compilation_mode = compilation_mode, stubs = stubs_int, imports = imports, + module = module, + module_prefix = module_prefix, tags = tags, visibility = visibility, **kwargs diff --git a/extensions/pyo3/test/module_prefix/BUILD.bazel b/extensions/pyo3/test/module_prefix/BUILD.bazel new file mode 100644 index 0000000000..75dfc8f1a0 --- /dev/null +++ b/extensions/pyo3/test/module_prefix/BUILD.bazel @@ -0,0 +1,17 @@ +load("@rules_python//python:defs.bzl", "py_test") +load("//:defs.bzl", "pyo3_extension") + +pyo3_extension( + name = "module_prefix", + srcs = ["bar.rs"], + edition = "2021", + imports = ["."], + module = "bar", + module_prefix = "foo", +) + +py_test( + name = "module_prefix_import_test", + srcs = ["module_prefix_import_test.py"], + deps = [":module_prefix"], +) diff --git a/extensions/pyo3/test/module_prefix/bar.rs b/extensions/pyo3/test/module_prefix/bar.rs new file mode 100644 index 0000000000..cfddb13a17 --- /dev/null +++ b/extensions/pyo3/test/module_prefix/bar.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; + +#[pyfunction] +fn thing() -> PyResult<&'static str> { + Ok("hello from rust") +} + +#[pymodule] +fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(thing, m)?)?; + Ok(()) +} diff --git a/extensions/pyo3/test/module_prefix/module_prefix_import_test.py b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py new file mode 100644 index 0000000000..389c51d4a3 --- /dev/null +++ b/extensions/pyo3/test/module_prefix/module_prefix_import_test.py @@ -0,0 +1,19 @@ +"""Tests that a pyo3 extension can be imported via a module prefix.""" + +import unittest + +import foo.bar # type: ignore + + +class ModulePrefixImportTest(unittest.TestCase): + """Test Class.""" + + def test_import_and_call(self) -> None: + """Test that a pyo3 extension can be imported via a module prefix.""" + + result = foo.bar.thing() + self.assertEqual("hello from rust", result) + + +if __name__ == "__main__": + unittest.main()