From 70f3a9e2f1ee1dad1262d33fe856627e46fd05a9 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Fri, 27 Feb 2026 09:22:28 +0100 Subject: [PATCH 1/4] Stub: allow custom imports These imports are written in a small DSL like: ```rust #[pymodule(stubs = { from datetime import datetime as dt, time; from uuid import UUID; })] ``` They are then parsed, sent as an AST inside the introspection data (following the same AST format as the type hints) and serialized by the introspection crate that merges these imports with the auto-generated ones. The `#[pymodule]` parameter is named `stubs` because we might include some other features in the future like protocols. --- newsfragments/5877.added.md | 1 + pyo3-introspection/src/introspection.rs | 59 +++++- pyo3-introspection/src/model.rs | 25 +++ pyo3-introspection/src/stubs.rs | 222 +++++++++++++++++++--- pyo3-macros-backend/src/attributes.rs | 6 + pyo3-macros-backend/src/introspection.rs | 31 +-- pyo3-macros-backend/src/json.rs | 75 ++++++++ pyo3-macros-backend/src/lib.rs | 4 + pyo3-macros-backend/src/module.rs | 18 +- pyo3-macros-backend/src/py_stubs.rs | 229 +++++++++++++++++++++++ pytests/src/annotations.rs | 30 +++ pytests/src/lib.rs | 6 + pytests/src/pyfunctions.rs | 14 -- pytests/stubs/annotations.pyi | 9 + pytests/stubs/pyfunctions.pyi | 3 - tests/test_compile_error.rs | 1 + 16 files changed, 665 insertions(+), 68 deletions(-) create mode 100644 newsfragments/5877.added.md create mode 100644 pyo3-macros-backend/src/json.rs create mode 100644 pyo3-macros-backend/src/py_stubs.rs create mode 100644 pytests/src/annotations.rs create mode 100644 pytests/stubs/annotations.pyi diff --git a/newsfragments/5877.added.md b/newsfragments/5877.added.md new file mode 100644 index 00000000000..5666e5d5c67 --- /dev/null +++ b/newsfragments/5877.added.md @@ -0,0 +1 @@ +Introspection: allow to set custom stub imports in `#[pymodule]` macro \ No newline at end of file diff --git a/pyo3-introspection/src/introspection.rs 
b/pyo3-introspection/src/introspection.rs index 34d23488729..7387acd6065 100644 --- a/pyo3-introspection/src/introspection.rs +++ b/pyo3-introspection/src/introspection.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use goblin::elf::section_header::SHN_XINDEX; @@ -52,6 +52,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { members, doc, incomplete, + stubs, } = chunk { if name == main_module_name { @@ -65,6 +66,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { name, members, *incomplete, + stubs, doc.as_deref(), &chunks_by_id, &chunks_by_parent, @@ -82,6 +84,7 @@ fn convert_module( name: &str, members: &[String], mut incomplete: bool, + stubs: &[ChunkStatement], docstring: Option<&str>, chunks_by_id: &HashMap<&str, &Chunk>, chunks_by_parent: &HashMap<&str, Vec<&Chunk>>, @@ -114,6 +117,35 @@ fn convert_module( functions, attributes, incomplete, + stubs: stubs + .iter() + .map(|statement| match statement { + ChunkStatement::ImportFrom { + module, + names, + level, + } => Statement::ImportFrom { + module: module.clone(), + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + level: *level, + }, + ChunkStatement::Import { names } => Statement::Import { + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + }, + }) + .collect(), docstring: docstring.map(Into::into), }) } @@ -139,12 +171,14 @@ fn convert_members<'a>( members, incomplete, doc, + stubs, } => { modules.push(convert_module( id, name, members, *incomplete, + stubs, doc.as_deref(), chunks_by_id, chunks_by_parent, @@ 
-686,6 +720,8 @@ enum Chunk { #[serde(default)] doc: Option, incomplete: bool, + #[serde(default)] + stubs: Vec, }, Class { id: String, @@ -753,6 +789,25 @@ struct ChunkArgument { annotation: Option, } +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum ChunkStatement { + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + Import { + names: Vec, + }, +} + +#[derive(Deserialize)] +struct ChunkAlias { + name: String, + asname: Option, +} + #[derive(Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] enum ChunkExpr { diff --git a/pyo3-introspection/src/model.rs b/pyo3-introspection/src/model.rs index b6ca8d28a8d..f30e93c8aa8 100644 --- a/pyo3-introspection/src/model.rs +++ b/pyo3-introspection/src/model.rs @@ -6,6 +6,7 @@ pub struct Module { pub functions: Vec, pub attributes: Vec, pub incomplete: bool, + pub stubs: Vec, pub docstring: Option, } @@ -74,6 +75,30 @@ pub struct VariableLengthArgument { pub annotation: Option, } +/// A python statement +/// +/// This is the `stmt` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub enum Statement { + /// `from {module} import {names}` + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + /// `import {names}` + Import { names: Vec }, +} + +/// A python import alias `{name} as {asname}` +/// +/// This is the `alias` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub struct ImportAlias { + pub name: String, + pub asname: Option, +} + /// A python expression /// /// This is the `expr` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) diff --git a/pyo3-introspection/src/stubs.rs b/pyo3-introspection/src/stubs.rs index 9877652cc40..9f285e47214 100644 --- a/pyo3-introspection/src/stubs.rs 
+++ b/pyo3-introspection/src/stubs.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::Write; @@ -95,7 +95,27 @@ fn module_stubs(module: &Module, parents: &[&str]) -> String { if let Some(docstring) = &module.docstring { final_elements.push(format!("\"\"\"\n{docstring}\n\"\"\"")); } - final_elements.extend(imports.imports); + for (module, names) in &imports.imports { + let mut line = String::new(); + if let Some(module) = module { + line.push_str("from "); + line.push_str(module); + line.push(' '); + } + line.push_str("import"); + for (i, name) in names.iter().enumerate() { + if i > 0 { + line.push(','); + } + line.push(' '); + line.push_str(&name.name); + if let Some(asname) = &name.asname { + line.push_str(" as "); + line.push_str(asname); + } + } + final_elements.push(line); + } final_elements.extend(elements); let mut output = String::new(); @@ -294,7 +314,7 @@ fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Im #[derive(Default)] struct Imports { /// Import lines ready to use - imports: Vec, + imports: BTreeMap, Vec>, /// Renaming map: from module name and member name return the name to use in type hints renaming: BTreeMap<(String, String), String>, } @@ -311,7 +331,7 @@ impl Imports { let mut elements_used_in_annotations = ElementsUsedInAnnotations::new(); elements_used_in_annotations.walk_module(module); - let mut imports = Vec::new(); + let mut imports = BTreeMap::, Vec>::new(); let mut renaming = BTreeMap::new(); let mut local_name_to_module_and_attribute = BTreeMap::new(); @@ -334,10 +354,38 @@ impl Imports { local_name_to_module_and_attribute .insert(name.clone(), (current_module_name.clone(), name.clone())); } - // We don't 
process the current module elements, no need to care about them - local_name_to_module_and_attribute.remove(¤t_module_name); - // We process then imports, normalizing local imports + // Also, we insert the imports from the user-provided stub + for statement in &module.stubs { + let (module, names) = match statement { + Statement::Import { names } => (None, names), + Statement::ImportFrom { + module, + names, + level, + } => { + // We build the python module relative path + let mut module = module.clone(); + for _ in 0..*level { + module = format!(".{module}") + } + (Some(module), names) + } + }; + for name in names { + let module_and_name = (module.clone().unwrap_or_default(), name.name.clone()); + let local_name = name.asname.as_ref().unwrap_or(&name.name).clone(); + local_name_to_module_and_attribute + .insert(local_name.clone(), module_and_name.clone()); + renaming.insert(module_and_name, local_name); + } + imports + .entry(module) + .or_default() + .extend(names.iter().cloned()); + } + + // Finally, We process imports from built-in annotations (always absolute) for (module, attrs) in &elements_used_in_annotations.module_to_name { let mut import_for_module = Vec::new(); for attr in attrs { @@ -345,6 +393,18 @@ impl Imports { let (root_attr, attr_path) = attr .split_once('.') .map_or((attr.as_str(), None), |(root, path)| (root, Some(path))); + + if let Some(local_name) = renaming.get(&(module.clone(), root_attr.to_owned())) { + // it's already imported, we make sure to get a renaming for the nested class if relevant + if let Some(attr_path) = &attr_path { + renaming.insert( + (module.clone(), attr.clone()), + format!("{local_name}.{attr_path}"), + ); + } + continue; + } + let mut local_name = root_attr.to_owned(); let mut already_imported = false; while let Some((possible_conflict_module, possible_conflict_attr)) = @@ -383,21 +443,31 @@ impl Imports { let is_not_aliased_builtin = module == "builtins" && local_name == root_attr; if !is_not_aliased_builtin { 
import_for_module.push(if local_name == root_attr { - local_name + ImportAlias { + name: local_name, + asname: None, + } } else { - format!("{root_attr} as {local_name}") + ImportAlias { + name: root_attr.into(), + asname: Some(local_name), + } }); } } } if !import_for_module.is_empty() { - imports.push(format!( - "from {module} import {}", - import_for_module.join(", ") - )); + imports + .entry(Some(module.clone())) + .or_default() + .extend(import_for_module); } } - imports.sort(); // We make sure they are sorted + + // We sort imports + for names in imports.values_mut() { + names.sort_by(|l, r| (&l.name, &l.asname).cmp(&(&r.name, &r.asname))); + } Self { imports, renaming } } @@ -628,7 +698,7 @@ impl ElementsUsedInAnnotations { #[cfg(test)] mod tests { use super::*; - use crate::model::Arguments; + use crate::model::{Arguments, ImportAlias, Statement}; #[test] fn function_stubs_with_variable_length() { @@ -769,6 +839,14 @@ mod tests { value: Box::new(Expr::Name { id: "foo".into() }), attr: "B".into(), }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "C".into(), + }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "D".into(), + }, Expr::Attribute { value: Box::new(Expr::Name { id: "bat".into() }), attr: "A".into(), @@ -831,22 +909,116 @@ mod tests { }], attributes: Vec::new(), incomplete: true, + stubs: vec![ + Statement::ImportFrom { + module: "foo".into(), + names: vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()), + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()), + }, + ImportAlias { + name: "C".into(), + asname: None, + }, + ], + level: 0, + }, + Statement::Import { + names: vec![ImportAlias { + name: "bat".into(), + asname: None, + }], + }, + Statement::ImportFrom { + module: "bat".into(), + names: vec![ImportAlias { + name: "D".into(), + asname: None, + }], + level: 0, + }, + ], docstring: None, }, &["foo"], ); assert_eq!( - &imports.imports, - &[ - 
"from _typeshed import Incomplete", - "from bat import A as A2", - "from builtins import int as int2", - "from foo import A as A3, B", - "from typing import final" - ] + imports.imports, + BTreeMap::from([ + ( + None, + vec![ImportAlias { + name: "bat".into(), + asname: None + }] + ), + ( + Some("_typeshed".to_string()), + vec![ImportAlias { + name: "Incomplete".into(), + asname: None + }] + ), + ( + Some("bat".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("A2".into()) + }, + ImportAlias { + name: "D".into(), + asname: None + } + ] + ), + ( + Some("builtins".into()), + vec![ImportAlias { + name: "int".into(), + asname: Some("int2".into()) + }] + ), + ( + Some("foo".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()) + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()) + }, + ImportAlias { + name: "C".into(), + asname: None + }, + ImportAlias { + name: "D".into(), + asname: Some("D2".into()) + } + ] + ), + ( + Some("typing".into()), + vec![ImportAlias { + name: "final".into(), + asname: None + }] + ), + ]) ); let mut output = String::new(); imports.serialize_expr(&big_type, &mut output); - assert_eq!(output, "dict[A, (A3.C, A3.D, B, A2, int, int2, float)]"); + assert_eq!( + output, + "dict[A, (AAlt.C, AAlt.D, B2, C, D2, A2, int, int2, float)]" + ); } } diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 9894c463628..221a2a43529 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -11,6 +11,8 @@ use syn::{ }; use crate::combine_errors::CombineErrors; +#[cfg(feature = "experimental-inspect")] +use crate::py_stubs::PyStubs; pub mod kw { syn::custom_keyword!(annotation); @@ -56,6 +58,8 @@ pub mod kw { syn::custom_keyword!(category); syn::custom_keyword!(from_py_object); syn::custom_keyword!(skip_from_py_object); + #[cfg(feature = "experimental-inspect")] + syn::custom_keyword!(stubs); } fn take_int(read: &mut 
&str, tracker: &mut usize) -> String { @@ -349,6 +353,8 @@ pub type TextSignatureAttribute = KeywordAttribute; pub type SubmoduleAttribute = kw::submodule; pub type GILUsedAttribute = KeywordAttribute; +#[cfg(feature = "experimental-inspect")] +pub type StubsAttribute = KeywordAttribute; impl Parse for KeywordAttribute { fn parse(input: ParseStream<'_>) -> Result { diff --git a/pyo3-macros-backend/src/introspection.rs b/pyo3-macros-backend/src/introspection.rs index aad7eb1f267..35862a9430b 100644 --- a/pyo3-macros-backend/src/introspection.rs +++ b/pyo3-macros-backend/src/introspection.rs @@ -8,8 +8,10 @@ //! The JSON blobs format must be synchronized with the `pyo3_introspection::introspection.rs::Chunk` //! type that is used to parse them. +use crate::json::escape_json_string; use crate::method::{FnArg, RegularArg}; use crate::py_expr::PyExpr; +use crate::py_stubs::PyStubs; use crate::pyfunction::FunctionSignature; use crate::utils::{PyO3CratePath, PythonDoc, StrOrExpr}; use proc_macro2::{Span, TokenStream}; @@ -17,7 +19,6 @@ use quote::{format_ident, quote, ToTokens}; use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; -use std::fmt::Write; use std::hash::{Hash, Hasher}; use std::mem::take; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -32,6 +33,7 @@ pub fn module_introspection_code<'a>( members_cfg_attrs: impl IntoIterator>, doc: Option<&PythonDoc>, incomplete: bool, + extra_stubs: Option<&PyStubs>, ) -> TokenStream { let mut desc = HashMap::from([ ("type", IntrospectionNode::String("module".into())), @@ -55,6 +57,9 @@ pub fn module_introspection_code<'a>( if let Some(doc) = doc { desc.insert("doc", IntrospectionNode::Doc(doc)); } + if let Some(stubs) = extra_stubs { + desc.insert("stubs", IntrospectionNode::Stubs(stubs)); + } IntrospectionNode::Map(desc).emit(pyo3_crate_path) } @@ -353,6 +358,7 @@ enum IntrospectionNode<'a> { IntrospectionId(Option>), TypeHint(Cow<'a, PyExpr>), Doc(&'a PythonDoc), + 
Stubs(&'a PyStubs), Map(HashMap<&'static str, IntrospectionNode<'a>>), List(Vec>), } @@ -411,6 +417,9 @@ impl IntrospectionNode<'_> { } content.push_str("\""); } + Self::Stubs(stubs) => { + content.push_str(&stubs.as_json().to_string()); + } Self::Map(map) => { content.push_str("{"); for (i, (key, value)) in map.into_iter().enumerate() { @@ -612,23 +621,3 @@ fn ident_to_type(ident: &Ident) -> Cow<'static, Type> { .into(), ) } - -fn escape_json_string(value: &str) -> String { - let mut output = String::with_capacity(value.len()); - for c in value.chars() { - match c { - '\\' => output.push_str("\\\\"), - '"' => output.push_str("\\\""), - '\x08' => output.push_str("\\b"), - '\x0C' => output.push_str("\\f"), - '\n' => output.push_str("\\n"), - '\r' => output.push_str("\\r"), - '\t' => output.push_str("\\t"), - c @ '\0'..='\x1F' => { - write!(output, "\\u{:0>4x}", u32::from(c)).unwrap(); - } - c => output.push(c), - } - } - output -} diff --git a/pyo3-macros-backend/src/json.rs b/pyo3-macros-backend/src/json.rs new file mode 100644 index 00000000000..dbb058fff65 --- /dev/null +++ b/pyo3-macros-backend/src/json.rs @@ -0,0 +1,75 @@ +//! 
JSON-related utilities + +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Write as _; + +pub enum JsonValue { + String(Cow<'static, str>), + Number(i16), + Array(Vec), + Object(HashMap<&'static str, JsonValue>), +} + +impl fmt::Display for JsonValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::String(value) => { + f.write_char('"')?; + write_escaped_json_string(value, f)?; + f.write_char('"') + } + Self::Number(value) => value.fmt(f), + JsonValue::Array(values) => { + f.write_char('[')?; + for (i, value) in values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + value.fmt(f)?; + } + f.write_char(']') + } + JsonValue::Object(key_values) => { + f.write_char('{')?; + for (i, (key, value)) in key_values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + f.write_char('"')?; + write_escaped_json_string(key, f)?; + f.write_char('"')?; + f.write_char(':')?; + value.fmt(f)?; + } + f.write_char('}') + } + } + } +} + +pub fn escape_json_string(value: &str) -> String { + let mut output = String::with_capacity(value.len()); + write_escaped_json_string(value, &mut output).unwrap(); + output +} + +fn write_escaped_json_string(value: &str, output: &mut impl fmt::Write) -> fmt::Result { + for c in value.chars() { + match c { + '\\' => output.write_str("\\\\"), + '"' => output.write_str("\\\""), + '\x08' => output.write_str("\\b"), + '\x0C' => output.write_str("\\f"), + '\n' => output.write_str("\\n"), + '\r' => output.write_str("\\r"), + '\t' => output.write_str("\\t"), + c @ '\0'..='\x1F' => { + write!(output, "\\u{:0>4x}", u32::from(c)) + } + c => output.write_char(c), + }?; + } + Ok(()) +} diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index a90fa73678e..87479c956d9 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -15,12 +15,16 @@ mod frompyobject; mod intopyobject; #[cfg(feature = "experimental-inspect")] mod 
introspection; +#[cfg(feature = "experimental-inspect")] +mod json; mod konst; mod method; mod module; mod params; #[cfg(feature = "experimental-inspect")] mod py_expr; +#[cfg(feature = "experimental-inspect")] +mod py_stubs; mod pyclass; mod pyfunction; mod pyimpl; diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index b3a127a4ba5..ec6c4f3c9d0 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -1,5 +1,7 @@ //! Code generation for the function that initializes a python module and adds classes and function. +#[cfg(feature = "experimental-inspect")] +use crate::attributes::StubsAttribute; #[cfg(feature = "experimental-inspect")] use crate::introspection::{ attribute_introspection_code, introspection_id_const, module_introspection_code, @@ -38,6 +40,8 @@ pub struct PyModuleOptions { module: Option, submodule: Option, gil_used: Option, + #[cfg(feature = "experimental-inspect")] + stubs: Option, } impl Parse for PyModuleOptions { @@ -83,9 +87,9 @@ impl PyModuleOptions { submodule, " (it is implicitly always specified for nested modules)" ), - PyModulePyO3Option::GILUsed(gil_used) => { - set_option!(gil_used) - } + PyModulePyO3Option::GILUsed(gil_used) => set_option!(gil_used), + #[cfg(feature = "experimental-inspect")] + PyModulePyO3Option::Stubs(stubs) => set_option!(stubs), } Ok(()) @@ -383,6 +387,7 @@ pub fn pymodule_module_impl( &module_items_cfg_attrs, doc.as_ref(), pymodule_init.is_some(), + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; @@ -471,6 +476,7 @@ pub fn pymodule_function_impl( &[], doc.as_ref(), true, + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! 
{}; @@ -727,6 +733,8 @@ enum PyModulePyO3Option { Name(NameAttribute), Module(ModuleAttribute), GILUsed(GILUsedAttribute), + #[cfg(feature = "experimental-inspect")] + Stubs(StubsAttribute), } impl Parse for PyModulePyO3Option { @@ -743,6 +751,10 @@ impl Parse for PyModulePyO3Option { } else if lookahead.peek(attributes::kw::gil_used) { input.parse().map(PyModulePyO3Option::GILUsed) } else { + #[cfg(feature = "experimental-inspect")] + if lookahead.peek(attributes::kw::stubs) { + return input.parse().map(PyModulePyO3Option::Stubs); + } Err(lookahead.error()) } } diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs new file mode 100644 index 00000000000..f58cf405dbf --- /dev/null +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -0,0 +1,229 @@ +//! Parsing and serialization code for custom type stubs + +use crate::json::JsonValue; +use proc_macro2::{Ident, TokenStream}; +use quote::ToTokens; +use std::collections::HashMap; +use syn::ext::IdentExt; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::{braced, Token}; + +mod kw { + syn::custom_keyword!(import); + syn::custom_keyword!(from); +} + +/// Custom provided stubs in #[pymodule] +pub struct PyStubs { + bracket_token: Brace, + imports: Punctuated, +} + +impl PyStubs { + /// Returns a JSON object following the https://docs.python.org/fr/3/library/ast.html syntax tree + pub fn as_json(&self) -> JsonValue { + JsonValue::Array(self.imports.iter().map(|i| i.as_json()).collect()) + } +} + +impl Parse for PyStubs { + fn parse(input: ParseStream<'_>) -> syn::Result { + let content; + Ok(Self { + bracket_token: braced!(content in input), + imports: content.parse_terminated(PyStatement::parse, Token![;])?, + }) + } +} + +impl ToTokens for PyStubs { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.bracket_token + .surround(tokens, |tokens| self.imports.to_tokens(tokens)) + } +} + +/// A Python statement +enum PyStatement { + 
ImportFrom(PyImportFrom), + Import(PyImport), +} + +impl PyStatement { + pub fn as_json(&self) -> JsonValue { + match self { + Self::ImportFrom(s) => s.as_json(), + Self::Import(s) => s.as_json(), + } + } +} + +impl Parse for PyStatement { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::from) { + input.parse().map(Self::ImportFrom) + } else if lookahead.peek(kw::import) { + input.parse().map(Self::Import) + } else { + Err(lookahead.error()) + } + } +} + +impl ToTokens for PyStatement { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::ImportFrom(s) => s.to_tokens(tokens), + Self::Import(s) => s.to_tokens(tokens), + } + } +} + +/// `from {module} import {names}` +struct PyImportFrom { + pub from_token: kw::from, + pub module: Ident, + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImportFrom { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("importfrom".into())), + ( + "module", + JsonValue::String(self.module.unraw().to_string().into()), + ), + ( + "names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ("level", JsonValue::Number(0)), + ])) + } +} + +impl Parse for PyImportFrom { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + from_token: input.parse()?, + module: input.parse()?, + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImportFrom { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.from_token.to_tokens(tokens); + self.module.to_tokens(tokens); + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `import {names}` +struct PyImport { + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImport { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("import".into())), + ( + 
"names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ])) + } +} + +impl Parse for PyImport { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImport { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `{name} [as {as_name}]` +struct PyAlias { + pub name: Ident, + pub as_name: Option, +} + +impl PyAlias { + pub fn as_json(&self) -> JsonValue { + let mut args = HashMap::from([ + ("type", JsonValue::String("alias".into())), + ( + "name", + JsonValue::String(self.name.unraw().to_string().into()), + ), + ]); + if let Some(as_name) = &self.as_name { + args.insert( + "asname", + JsonValue::String(as_name.name.unraw().to_string().into()), + ); + } + JsonValue::Object(args) + } +} + +impl Parse for PyAlias { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + name: input.parse()?, + as_name: if input.lookahead1().peek(Token![as]) { + Some(input.parse()?) + } else { + None + }, + }) + } +} + +impl ToTokens for PyAlias { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.name.to_tokens(tokens); + self.as_name.to_tokens(tokens); + } +} + +/// `as {name}` +struct PyAliasAsName { + pub as_token: Token![as], + pub name: Ident, +} + +impl Parse for PyAliasAsName { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + as_token: input.parse()?, + name: input.parse()?, + }) + } +} + +impl ToTokens for PyAliasAsName { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.as_token.to_tokens(tokens); + self.name.to_tokens(tokens); + } +} diff --git a/pytests/src/annotations.rs b/pytests/src/annotations.rs new file mode 100644 index 00000000000..1abe4fc77d1 --- /dev/null +++ b/pytests/src/annotations.rs @@ -0,0 +1,30 @@ +//! Example of custom annotations. 
+ +use pyo3::prelude::*; + +#[pymodule(stubs = { + from datetime import datetime as dt, time; + from uuid import UUID; +})] +pub mod annotations { + use pyo3::prelude::*; + use pyo3::types::{PyDate, PyDateTime, PyDict, PyTime, PyTuple}; + + #[pyfunction(signature = (a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] + fn with_custom_type_annotations<'py>( + a: Bound<'py, PyAny>, + _args: Bound<'py, PyTuple>, + _b: Option>, + _kwargs: Option>, + ) -> Bound<'py, PyAny> { + a + } + + #[pyfunction] + fn with_built_in_type_annotations( + _date_time: Bound<'_, PyDateTime>, + _time: Bound<'_, PyTime>, + _date: Bound<'_, PyDate>, + ) { + } +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index f6f4b151e6e..00b8d4f8e21 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -1,6 +1,8 @@ use pyo3::prelude::*; use pyo3::types::PyDict; +#[cfg(feature = "experimental-inspect")] +mod annotations; mod awaitable; mod buf_and_str; mod comparisons; @@ -32,6 +34,10 @@ mod pyo3_pytests { #[pymodule_export] use datetime::datetime; + #[cfg(feature = "experimental-inspect")] + #[pymodule_export] + use annotations::annotations; + #[pymodule_export] use { awaitable::awaitable, comparisons::comparisons, consts::consts, dict_iter::dict_iter, diff --git a/pytests/src/pyfunctions.rs b/pytests/src/pyfunctions.rs index 6e1015e7627..e1ffb444cac 100644 --- a/pytests/src/pyfunctions.rs +++ b/pytests/src/pyfunctions.rs @@ -77,17 +77,6 @@ fn with_typed_args(a: bool, b: u64, c: f64, d: &str) -> (bool, u64, f64, &str) { (a, b, c, d) } -#[cfg(feature = "experimental-inspect")] -#[pyfunction(signature = (a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] -fn with_custom_type_annotations<'py>( - a: Any<'py>, - _args: Tuple<'py>, - _b: Option>, - _kwargs: Option>, -) -> Any<'py> { - a -} - #[cfg(feature = "experimental-async")] #[pyfunction] async fn with_async() {} @@ -143,9 +132,6 @@ pub mod pyfunctions { #[cfg(feature = 
"experimental-async")] #[pymodule_export] use super::with_async; - #[cfg(feature = "experimental-inspect")] - #[pymodule_export] - use super::with_custom_type_annotations; #[pymodule_export] use super::{ args_kwargs, many_keyword_arguments, none, positional_only, simple, simple_args, diff --git a/pytests/stubs/annotations.pyi b/pytests/stubs/annotations.pyi new file mode 100644 index 00000000000..e7e3fb32c97 --- /dev/null +++ b/pytests/stubs/annotations.pyi @@ -0,0 +1,9 @@ +from datetime import date, datetime as dt, time +from uuid import UUID + +def with_built_in_type_annotations( + _date_time: dt, _time: time, _date: date +) -> None: ... +def with_custom_type_annotations( + a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" +) -> "int": ... diff --git a/pytests/stubs/pyfunctions.pyi b/pytests/stubs/pyfunctions.pyi index 1d5cca9a33d..322f2642339 100644 --- a/pytests/stubs/pyfunctions.pyi +++ b/pytests/stubs/pyfunctions.pyi @@ -35,9 +35,6 @@ def simple_kwargs( a: Any, b: Any | None = None, c: Any | None = None, **kwargs ) -> tuple[Any, Any | None, Any | None, dict | None]: ... async def with_async() -> None: ... -def with_custom_type_annotations( - a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" -) -> "int": ... def with_typed_args( a: bool = False, b: int = 0, c: float = 0.0, d: str = "" ) -> tuple[bool, int, float, str]: ... 
diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index e5b646239d7..233537bcfec 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -32,6 +32,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs"); t.compile_fail("tests/ui/invalid_pymethod_enum.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + #[cfg(not(feature = "experimental-inspect"))] t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/invalid_pycallargs.rs"); t.compile_fail("tests/ui/reject_generics.rs"); From b1b6b5400164b6802dc2f85846a52786bdbeb7d6 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Tue, 14 Apr 2026 11:32:38 +0200 Subject: [PATCH 2/4] Remove `;` from the statements syntax --- pyo3-macros-backend/src/py_stubs.rs | 19 ++++++++++++++----- pytests/src/annotations.rs | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs index f58cf405dbf..a26fb6d64de 100644 --- a/pyo3-macros-backend/src/py_stubs.rs +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -18,13 +18,13 @@ mod kw { /// Custom provided stubs in #[pymodule] pub struct PyStubs { bracket_token: Brace, - imports: Punctuated, + statements: Vec, } impl PyStubs { /// Returns a JSON object following the https://docs.python.org/fr/3/library/ast.html syntax tree pub fn as_json(&self) -> JsonValue { - JsonValue::Array(self.imports.iter().map(|i| i.as_json()).collect()) + JsonValue::Array(self.statements.iter().map(|i| i.as_json()).collect()) } } @@ -33,15 +33,24 @@ impl Parse for PyStubs { let content; Ok(Self { bracket_token: braced!(content in input), - imports: content.parse_terminated(PyStatement::parse, Token![;])?, + statements: { + let mut statements = Vec::new(); + while !content.is_empty() { + statements.push(content.parse()?); + } + statements + }, }) } } impl ToTokens for PyStubs { fn to_tokens(&self, tokens: &mut 
TokenStream) { - self.bracket_token - .surround(tokens, |tokens| self.imports.to_tokens(tokens)) + self.bracket_token.surround(tokens, |tokens| { + for import in &self.statements { + import.to_tokens(tokens) + } + }) } } diff --git a/pytests/src/annotations.rs b/pytests/src/annotations.rs index 1abe4fc77d1..80576712c86 100644 --- a/pytests/src/annotations.rs +++ b/pytests/src/annotations.rs @@ -3,8 +3,8 @@ use pyo3::prelude::*; #[pymodule(stubs = { - from datetime import datetime as dt, time; - from uuid import UUID; + from datetime import datetime as dt, time + from uuid import UUID })] pub mod annotations { use pyo3::prelude::*; From d9eb35ad90ea478cbd7e58dd61c429182b1cf026 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Tue, 14 Apr 2026 12:11:50 +0200 Subject: [PATCH 3/4] Allow dotted names in imports --- pyo3-macros-backend/src/py_stubs.rs | 58 +++++++++++++++++++++-------- pytests/src/annotations.rs | 3 +- pytests/stubs/annotations.pyi | 6 ++- 3 files changed, 50 insertions(+), 17 deletions(-) diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs index a26fb6d64de..55a6548aff0 100644 --- a/pyo3-macros-backend/src/py_stubs.rs +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -4,6 +4,8 @@ use crate::json::JsonValue; use proc_macro2::{Ident, TokenStream}; use quote::ToTokens; use std::collections::HashMap; +use std::fmt; +use std::fmt::Write; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; @@ -29,6 +31,7 @@ impl PyStubs { } impl Parse for PyStubs { + /// See https://docs.python.org/3/reference/grammar.html for the Python grammar fn parse(input: ParseStream<'_>) -> syn::Result { let content; Ok(Self { @@ -57,7 +60,7 @@ impl ToTokens for PyStubs { /// A Python statement enum PyStatement { ImportFrom(PyImportFrom), - Import(PyImport), + Import(PyImportName), } impl PyStatement { @@ -94,7 +97,7 @@ impl ToTokens for PyStatement { /// `from {module} import {names}` struct 
PyImportFrom { pub from_token: kw::from, - pub module: Ident, + pub module: PyDottedName, pub import_token: kw::import, pub names: Punctuated, } @@ -103,10 +106,7 @@ impl PyImportFrom { pub fn as_json(&self) -> JsonValue { JsonValue::Object(HashMap::from([ ("type", JsonValue::String("importfrom".into())), - ( - "module", - JsonValue::String(self.module.unraw().to_string().into()), - ), + ("module", JsonValue::String(self.module.to_string().into())), ( "names", JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), @@ -137,12 +137,12 @@ impl ToTokens for PyImportFrom { } /// `import {names}` -struct PyImport { +struct PyImportName { pub import_token: kw::import, pub names: Punctuated, } -impl PyImport { +impl PyImportName { pub fn as_json(&self) -> JsonValue { JsonValue::Object(HashMap::from([ ("type", JsonValue::String("import".into())), @@ -154,7 +154,7 @@ impl PyImport { } } -impl Parse for PyImport { +impl Parse for PyImportName { fn parse(input: ParseStream<'_>) -> syn::Result { Ok(Self { import_token: input.parse()?, @@ -163,7 +163,7 @@ impl Parse for PyImport { } } -impl ToTokens for PyImport { +impl ToTokens for PyImportName { fn to_tokens(&self, tokens: &mut TokenStream) { self.import_token.to_tokens(tokens); self.names.to_tokens(tokens); @@ -172,7 +172,7 @@ impl ToTokens for PyImport { /// `{name} [as {as_name}]` struct PyAlias { - pub name: Ident, + pub name: PyDottedName, pub as_name: Option, } @@ -180,10 +180,7 @@ impl PyAlias { pub fn as_json(&self) -> JsonValue { let mut args = HashMap::from([ ("type", JsonValue::String("alias".into())), - ( - "name", - JsonValue::String(self.name.unraw().to_string().into()), - ), + ("name", JsonValue::String(self.name.to_string().into())), ]); if let Some(as_name) = &self.as_name { args.insert( @@ -236,3 +233,34 @@ impl ToTokens for PyAliasAsName { self.name.to_tokens(tokens); } } + +/// `{ident}.*` +struct PyDottedName { + pub idents: Punctuated, +} + +impl Parse for PyDottedName { + fn parse(input: 
ParseStream<'_>) -> syn::Result { + Ok(Self { + idents: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyDottedName { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.idents.to_tokens(tokens); + } +} + +impl fmt::Display for PyDottedName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (i, ident) in self.idents.iter().enumerate() { + if i > 0 { + f.write_char('.')?; + } + ident.unraw().fmt(f)?; + } + Ok(()) + } +} diff --git a/pytests/src/annotations.rs b/pytests/src/annotations.rs index 80576712c86..986868c8bff 100644 --- a/pytests/src/annotations.rs +++ b/pytests/src/annotations.rs @@ -5,12 +5,13 @@ use pyo3::prelude::*; #[pymodule(stubs = { from datetime import datetime as dt, time from uuid import UUID + from collections.abc import Sequence })] pub mod annotations { use pyo3::prelude::*; use pyo3::types::{PyDate, PyDateTime, PyDict, PyTime, PyTuple}; - #[pyfunction(signature = (a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] + #[pyfunction(signature = (a: "dt | time | UUID", *_args: "str", _b: "Sequence[int] | None" = None, **_kwargs: "bool") -> "int")] fn with_custom_type_annotations<'py>( a: Bound<'py, PyAny>, _args: Bound<'py, PyTuple>, diff --git a/pytests/stubs/annotations.pyi b/pytests/stubs/annotations.pyi index e7e3fb32c97..82035a1a9d8 100644 --- a/pytests/stubs/annotations.pyi +++ b/pytests/stubs/annotations.pyi @@ -1,3 +1,4 @@ +from collections.abc import Sequence from datetime import date, datetime as dt, time from uuid import UUID @@ -5,5 +6,8 @@ def with_built_in_type_annotations( _date_time: dt, _time: time, _date: date ) -> None: ... def with_custom_type_annotations( - a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" + a: "dt | time | UUID", + *_args: "str", + _b: "Sequence[int] | None" = None, + **_kwargs: "bool", ) -> "int": ... 
From a7b8a4aa1cf3738272a9ea0a482bbb5c9b0325c3 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Tue, 14 Apr 2026 12:23:00 +0200 Subject: [PATCH 4/4] Ensure statements are on a single line --- pyo3-macros-backend/Cargo.toml | 2 +- pyo3-macros-backend/src/intopyobject.rs | 1 + pyo3-macros-backend/src/py_stubs.rs | 36 ++++++++++++++++++++++--- pyo3-macros-backend/src/utils.rs | 1 + 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/pyo3-macros-backend/Cargo.toml b/pyo3-macros-backend/Cargo.toml index 3ef7c8df604..faff45c8c95 100644 --- a/pyo3-macros-backend/Cargo.toml +++ b/pyo3-macros-backend/Cargo.toml @@ -30,4 +30,4 @@ workspace = true [features] experimental-async = [] -experimental-inspect = [] +experimental-inspect = ["proc-macro2/span-locations"] diff --git a/pyo3-macros-backend/src/intopyobject.rs b/pyo3-macros-backend/src/intopyobject.rs index a4e557e4ceb..a9e21583560 100644 --- a/pyo3-macros-backend/src/intopyobject.rs +++ b/pyo3-macros-backend/src/intopyobject.rs @@ -11,6 +11,7 @@ use syn::{parse_quote, DataEnum, DeriveInput, Fields, Ident, Index, Result}; struct ItemOption(Option); +#[cfg_attr(feature = "experimental-inspect", expect(clippy::large_enum_variant))] enum IntoPyObjectTypes { Transparent(syn::Type), Opaque { diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs index 55a6548aff0..493ee449bf8 100644 --- a/pyo3-macros-backend/src/py_stubs.rs +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -9,6 +9,7 @@ use std::fmt::Write; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::token::Brace; use syn::{braced, Token}; @@ -34,7 +35,7 @@ impl Parse for PyStubs { /// See https://docs.python.org/3/reference/grammar.html for the Python grammar fn parse(input: ParseStream<'_>) -> syn::Result { let content; - Ok(Self { + let statements = Self { bracket_token: braced!(content in input), statements: { let mut statements = 
Vec::new(); @@ -43,7 +44,25 @@ impl Parse for PyStubs { } statements }, - }) + }; + + // Ensure that two statements are not in the same line + for (s1, s2) in statements + .statements + .iter() + .zip(statements.statements.iter().skip(1)) + { + let s1span = s1.span(); + let s2span = s2.span(); + if s1span.end().line == s2span.start().line { + return Err(syn::Error::new( + s1span.join(s2span).unwrap(), + "Each Python statement (import...) must be on its own line", + )); + } + } + + Ok(statements) } } @@ -75,13 +94,24 @@ impl PyStatement { impl Parse for PyStatement { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); - if lookahead.peek(kw::from) { + let statement = if lookahead.peek(kw::from) { input.parse().map(Self::ImportFrom) } else if lookahead.peek(kw::import) { input.parse().map(Self::Import) } else { Err(lookahead.error()) + }?; + + // Ensure that the statement is on a single line + let span = statement.span(); + if span.start().line != span.end().line { + return Err(syn::Error::new( + span, + "Python statements must be on a single line", + )); } + + Ok(statement) } } diff --git a/pyo3-macros-backend/src/utils.rs b/pyo3-macros-backend/src/utils.rs index fa0dd90664b..e0d44139382 100644 --- a/pyo3-macros-backend/src/utils.rs +++ b/pyo3-macros-backend/src/utils.rs @@ -168,6 +168,7 @@ impl PythonDoc { } /// A plain string or an expression +#[cfg_attr(feature = "experimental-inspect", expect(clippy::large_enum_variant))] #[derive(Clone)] pub enum StrOrExpr { Str { value: String, span: Option },