Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions rust/ruby-rbs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,30 @@ impl NodeField {
}
}

#[derive(Debug, Deserialize)]
struct LocationField {
#[serde(default)]
required: Option<String>,
#[serde(default)]
optional: Option<String>,
}

impl LocationField {
fn name(&self) -> &str {
self.required.as_ref().or(self.optional.as_ref()).unwrap()
}

fn is_required(&self) -> bool {
self.required.is_some()
}
}

#[derive(Debug, Deserialize)]
struct Node {
name: String,
rust_name: String,
fields: Option<Vec<NodeField>>,
locations: Option<Vec<LocationField>>,
}

impl Node {
Expand Down Expand Up @@ -72,6 +91,7 @@ fn main() -> Result<(), Box<dyn Error>> {
name: "RBS::AST::Symbol".to_string(),
rust_name: "SymbolNode".to_string(),
fields: None,
locations: None,
});

config.nodes.sort_by(|a, b| a.name.cmp(&b.name));
Expand Down Expand Up @@ -487,6 +507,59 @@ fn generate(config: &Config) -> Result<(), Box<dyn Error>> {
writeln!(file, " }}")?;
writeln!(file)?;

// Generate location accessor methods
if let Some(locations) = &node.locations {
for location in locations {
let location_name = location.name();
let method_name = format!("{}_location", location_name);
let field_name = format!("{}_range", location_name);

if location.is_required() {
writeln!(
file,
" /// Returns the `{}` sub-location of this node.",
location_name
)?;
writeln!(file, " #[must_use]")?;
writeln!(
file,
" pub fn {}(&self) -> RBSLocationRange {{",
method_name
)?;
writeln!(
file,
" RBSLocationRange::new(unsafe {{ (*self.pointer).{} }})",
field_name
)?;
writeln!(file, " }}")?;
} else {
writeln!(
file,
" /// Returns the `{}` sub-location of this node if present.",
location_name
)?;
writeln!(file, " #[must_use]")?;
writeln!(
file,
" pub fn {}(&self) -> Option<RBSLocationRange> {{",
method_name
)?;
writeln!(
file,
" let range = unsafe {{ (*self.pointer).{} }};",
field_name
)?;
writeln!(file, " if range.start_char == -1 {{")?;
writeln!(file, " None")?;
writeln!(file, " }} else {{")?;
writeln!(file, " Some(RBSLocationRange::new(range))")?;
writeln!(file, " }}")?;
writeln!(file, " }}")?;
}
writeln!(file)?;
}
}

if let Some(fields) = &node.fields {
for field in fields {
match field.c_type.as_str() {
Expand Down
60 changes: 60 additions & 0 deletions rust/ruby-rbs/examples/locations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use ruby_rbs::node::{Node, parse};

fn main() {
let rbs_code = r#"class Foo[T] < Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
if let Node::Class(class) = declaration {
println!("Class declaration: '{}'", rbs_code);
println!(
"Overall location: {}..{}",
class.location().start(),
class.location().end()
);

// Required sub-locations
let keyword = class.keyword_location();
println!(
" keyword location: {}..{} = '{}'",
keyword.start(),
keyword.end(),
&rbs_code[keyword.start() as usize..keyword.end() as usize]
);

let name = class.name_location();
println!(
" name location: {}..{} = '{}'",
name.start(),
name.end(),
&rbs_code[name.start() as usize..name.end() as usize]
);

let end_loc = class.end_location();
println!(
" end location: {}..{} = '{}'",
end_loc.start(),
end_loc.end(),
&rbs_code[end_loc.start() as usize..end_loc.end() as usize]
);

// Optional sub-locations
if let Some(type_params) = class.type_params_location() {
println!(
" type_params location: {}..{} = '{}'",
type_params.start(),
type_params.end(),
&rbs_code[type_params.start() as usize..type_params.end() as usize]
);
}

if let Some(lt) = class.lt_location() {
println!(
" lt location: {}..{} = '{}'",
lt.start(),
lt.end(),
&rbs_code[lt.start() as usize..lt.end() as usize]
);
}
}
}
106 changes: 106 additions & 0 deletions rust/ruby-rbs/src/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,112 @@ mod tests {
assert_eq!(12, int_loc.end());
}

#[test]
fn test_sub_locations() {
let rbs_code = r#"class Foo < Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::Class(class) = declaration else {
panic!("Expected Class");
};

// Test required sub-locations
let keyword_loc = class.keyword_location();
assert_eq!(0, keyword_loc.start());
assert_eq!(5, keyword_loc.end());

let name_loc = class.name_location();
assert_eq!(6, name_loc.start());
assert_eq!(9, name_loc.end());

let end_loc = class.end_location();
assert_eq!(16, end_loc.start());
assert_eq!(19, end_loc.end());

// Test optional sub-location that's present
let lt_loc = class.lt_location();
assert!(lt_loc.is_some());
let lt = lt_loc.unwrap();
assert_eq!(10, lt.start());
assert_eq!(11, lt.end());

// Test optional sub-location that's not present (no type params in this class)
let type_params_loc = class.type_params_location();
assert!(type_params_loc.is_none());
}

#[test]
fn test_type_alias_sub_locations() {
let rbs_code = r#"type foo = String"#;
let signature = parse(rbs_code.as_bytes()).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::TypeAlias(type_alias) = declaration else {
panic!("Expected TypeAlias");
};

// Test required sub-locations
let keyword_loc = type_alias.keyword_location();
assert_eq!(0, keyword_loc.start());
assert_eq!(4, keyword_loc.end());

let name_loc = type_alias.name_location();
assert_eq!(5, name_loc.start());
assert_eq!(8, name_loc.end());

let eq_loc = type_alias.eq_location();
assert_eq!(9, eq_loc.start());
assert_eq!(10, eq_loc.end());

// Test optional sub-location that's not present (no type params)
let type_params_loc = type_alias.type_params_location();
assert!(type_params_loc.is_none());
}

#[test]
fn test_module_sub_locations() {
let rbs_code = r#"module Foo[T] : Bar end"#;
let signature = parse(rbs_code.as_bytes()).unwrap();

let declaration = signature.declarations().iter().next().unwrap();
let Node::Module(module) = declaration else {
panic!("Expected Module");
};

// Test required sub-locations
let keyword_loc = module.keyword_location();
assert_eq!(0, keyword_loc.start());
assert_eq!(6, keyword_loc.end());

let name_loc = module.name_location();
assert_eq!(7, name_loc.start());
assert_eq!(10, name_loc.end());

let end_loc = module.end_location();
assert_eq!(20, end_loc.start());
assert_eq!(23, end_loc.end());

// Test optional sub-locations that are present
let type_params_loc = module.type_params_location();
assert!(type_params_loc.is_some());
let tp = type_params_loc.unwrap();
assert_eq!(10, tp.start());
assert_eq!(13, tp.end());

let colon_loc = module.colon_location();
assert!(colon_loc.is_some());
let colon = colon_loc.unwrap();
assert_eq!(14, colon.start());
assert_eq!(15, colon.end());

let self_types_loc = module.self_types_location();
assert!(self_types_loc.is_some());
let st = self_types_loc.unwrap();
assert_eq!(16, st.start());
assert_eq!(19, st.end());
}

#[test]
fn test_enum_types() {
let rbs_code = r#"
Expand Down