diff --git a/lib/ruby_lsp/ruby_lsp_rails/definition.rb b/lib/ruby_lsp/ruby_lsp_rails/definition.rb index 133d7760..1c3739fe 100644 --- a/lib/ruby_lsp/ruby_lsp_rails/definition.rb +++ b/lib/ruby_lsp/ruby_lsp_rails/definition.rb @@ -80,7 +80,7 @@ def handle_possible_dsl(node) return unless arguments if Support::Associations::ALL.include?(message) - handle_association(call_node) + handle_association(node, call_node) elsif Support::Callbacks::ALL.include?(message) handle_callback(node, call_node, arguments) handle_if_unless_conditional(node, call_node, arguments) @@ -125,18 +125,57 @@ def handle_validation(node, call_node, arguments) collect_definitions(name) end - #: (Prism::CallNode node) -> void - def handle_association(node) - first_argument = node.arguments&.arguments&.first - return unless first_argument.is_a?(Prism::SymbolNode) + #: ((Prism::SymbolNode | Prism::StringNode) node, Prism::CallNode call_node) -> void + def handle_association(node, call_node) + arguments = call_node.arguments&.arguments + return unless arguments + + first_argument = arguments.first + return unless first_argument.is_a?(Prism::SymbolNode) || first_argument.is_a?(Prism::StringNode) + + association_name = extract_string_from_node(first_argument) + return unless association_name + + through_element = find_through_association_element(arguments) + clicked_symbol = extract_string_from_node(node) + return unless clicked_symbol + + if through_element + through_association_name = extract_string_from_node(through_element.value) + + if clicked_symbol == association_name + handle_association_name(association_name) + elsif through_association_name && clicked_symbol == through_association_name + handle_association_name(through_association_name) + end + else + handle_association_name(association_name) + end + end - association_name = first_argument.unescaped + #: (Array[Prism::Node]) -> Prism::AssocNode? + def find_through_association_element(arguments) + result = arguments + .filter_map { |arg| arg.elements if arg.is_a?(Prism::KeywordHashNode) } + .flatten + .find do |elem| + next false unless elem.is_a?(Prism::AssocNode) + key = elem.key + next false unless key.is_a?(Prism::SymbolNode) + + key.value == "through" + end + + result if result.is_a?(Prism::AssocNode) + end + + #: (String association_name) -> void + def handle_association_name(association_name) result = @client.association_target( model_name: @nesting.join("::"), association_name: association_name, ) - return unless result @response_builder << Support::LocationBuilder.line_location_from_s(result.fetch(:location)) @@ -194,6 +233,16 @@ def handle_if_unless_conditional(node, call_node, arguments) collect_definitions(method_name) end + + #: (Prism::Node) -> String? + def extract_string_from_node(node) + case node + when Prism::SymbolNode + node.unescaped + when Prism::StringNode + node.content + end + end end end end diff --git a/lib/ruby_lsp/ruby_lsp_rails/hover.rb b/lib/ruby_lsp/ruby_lsp_rails/hover.rb index 2bd35379..0b5b489b 100644 --- a/lib/ruby_lsp/ruby_lsp_rails/hover.rb +++ b/lib/ruby_lsp/ruby_lsp_rails/hover.rb @@ -29,6 +29,7 @@ def initialize(client, response_builder, node_context, global_state, dispatcher) :on_constant_path_node_enter, :on_constant_read_node_enter, :on_symbol_node_enter, + :on_string_node_enter, ) end @@ -56,6 +57,11 @@ def on_symbol_node_enter(node) handle_possible_dsl(node) end + #: (Prism::StringNode node) -> void + def on_string_node_enter(node) + handle_possible_dsl(node) + end + private #: (String name) -> void @@ -116,28 +122,67 @@ def format_default(default_value, type) end end - #: (Prism::SymbolNode node) -> void + #: ((Prism::SymbolNode | Prism::StringNode) node) -> void def handle_possible_dsl(node) - node = @node_context.call_node - return unless node - return unless self_receiver?(node) - - message = node.message + call_node = @node_context.call_node + return unless call_node + return unless self_receiver?(call_node) + message = call_node.message return unless message if Support::Associations::ALL.include?(message) - handle_association(node) + handle_association(node, call_node) + end + end + + #: ((Prism::SymbolNode | Prism::StringNode) node, Prism::CallNode call_node) -> void + def handle_association(node, call_node) + arguments = call_node.arguments&.arguments + return unless arguments + + first_argument = arguments.first + return unless first_argument.is_a?(Prism::SymbolNode) || first_argument.is_a?(Prism::StringNode) + + association_name = extract_string_from_node(first_argument) + return unless association_name + + through_element = find_through_association_element(arguments) + clicked_symbol = extract_string_from_node(node) + return unless clicked_symbol + + if through_element + through_association_name = extract_string_from_node(through_element.value) + + if clicked_symbol == association_name + handle_association_name(association_name) + elsif through_association_name && clicked_symbol == through_association_name + handle_association_name(through_association_name) + end + else + handle_association_name(association_name) end end - #: (Prism::CallNode node) -> void - def handle_association(node) - first_argument = node.arguments&.arguments&.first - return unless first_argument.is_a?(Prism::SymbolNode) + #: (Array[Prism::Node]) -> Prism::AssocNode? + def find_through_association_element(arguments) + result = arguments + .filter_map { |arg| arg.elements if arg.is_a?(Prism::KeywordHashNode) } + .flatten + .find do |elem| + next false unless elem.is_a?(Prism::AssocNode) + + key = elem.key + next false unless key.is_a?(Prism::SymbolNode) - association_name = first_argument.unescaped + key.value == "through" + end + result if result.is_a?(Prism::AssocNode) + end + + #: (String association_name) -> void + def handle_association_name(association_name) result = @client.association_target( model_name: @nesting.join("::"), association_name: association_name, @@ -164,6 +209,16 @@ def generate_hover(name) @response_builder.push(content, category: category) end end + + #: (Prism::Node) -> String? + def extract_string_from_node(node) + case node + when Prism::SymbolNode + node.unescaped + when Prism::StringNode + node.content + end + end end end end diff --git a/test/dummy/app/models/country.rb b/test/dummy/app/models/country.rb index 8a28ba43..2e8dc6ff 100644 --- a/test/dummy/app/models/country.rb +++ b/test/dummy/app/models/country.rb @@ -1,4 +1,5 @@ # frozen_string_literal: true class Country < ApplicationRecord + has_one :flag, dependent: :destroy end diff --git a/test/dummy/app/models/flag.rb b/test/dummy/app/models/flag.rb new file mode 100644 index 00000000..b4970866 --- /dev/null +++ b/test/dummy/app/models/flag.rb @@ -0,0 +1,5 @@ +# frozen_string_literal: true + +class Flag < ApplicationRecord + belongs_to :country +end diff --git a/test/dummy/app/models/user.rb b/test/dummy/app/models/user.rb index c696c86d..4124b35c 100644 --- a/test/dummy/app/models/user.rb +++ b/test/dummy/app/models/user.rb @@ -5,7 +5,8 @@ class User < ApplicationRecord validates :first_name, presence: true has_one :profile scope :adult, -> { where(age: 18..) } - has_one :location, class_name: "Country" + belongs_to :location, class_name: "Country" + has_one :country_flag, through: :location, source: :flag attr_readonly :last_name diff --git a/test/dummy/db/migrate/20250703132109_create_flags.rb b/test/dummy/db/migrate/20250703132109_create_flags.rb new file mode 100644 index 00000000..23ffa324 --- /dev/null +++ b/test/dummy/db/migrate/20250703132109_create_flags.rb @@ -0,0 +1,9 @@ +class CreateFlags < ActiveRecord::Migration[8.0] + def change + create_table :flags do |t| + t.references :country, null: false, foreign_key: true + + t.timestamps + end + end +end diff --git a/test/dummy/db/schema.rb b/test/dummy/db/schema.rb index fa61c344..f103abac 100644 --- a/test/dummy/db/schema.rb +++ b/test/dummy/db/schema.rb @@ -10,7 +10,7 @@ # # It's strongly recommended that you check this file into your version control system. -ActiveRecord::Schema[8.0].define(version: 2024_10_25_225348) do +ActiveRecord::Schema[8.0].define(version: 2025_07_03_132109) do create_table "composite_primary_keys", primary_key: ["order_id", "product_id"], force: :cascade do |t| t.integer "order_id" t.integer "product_id" @@ -25,6 +25,13 @@ t.datetime "updated_at", null: false end + create_table "flags", force: :cascade do |t| + t.integer "country_id", null: false + t.datetime "created_at", null: false + t.datetime "updated_at", null: false + t.index ["country_id"], name: "index_flags_on_country_id" + end + create_table "memberships", force: :cascade do |t| t.integer "user_id", null: false t.integer "organization_id", null: false @@ -58,6 +65,7 @@ t.index ["country_id"], name: "index_users_on_country_id" end + add_foreign_key "flags", "countries" add_foreign_key "memberships", "organizations" add_foreign_key "memberships", "users" add_foreign_key "users", "countries" diff --git a/test/ruby_lsp_rails/definition_test.rb b/test/ruby_lsp_rails/definition_test.rb index bc8c03e0..e42a1bc1 100644 --- a/test/ruby_lsp_rails/definition_test.rb +++ b/test/ruby_lsp_rails/definition_test.rb @@ -53,6 +53,42 @@ class Organization < ActiveRecord::Base assert_equal(2, response[0].range.end.line) end + test "recognizes main association on has_many :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 13 }) + class Organization < ActiveRecord::Base + has_many :memberships + has_many :users, through: :memberships + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "user.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + + test "recognizes through association on has_many :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 30 }) + class Organization < ActiveRecord::Base + has_many :memberships + has_many :users, through: :memberships + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "membership.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + test "recognizes belongs_to model associations" do response = generate_definitions_for_source(<<~RUBY, { line: 3, character: 14 }) # typed: false @@ -91,6 +127,42 @@ class User < ActiveRecord::Base assert_equal(2, response[0].range.end.line) end + test "recognizes main association on has_one :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 12 }) + class User < ActiveRecord::Base + belongs_to :location, class_name: "Country" + has_one :country_flag, through: :location, source: :flag + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "flag.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + + test "recognizes through association on has_one :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 36 }) + class User < ActiveRecord::Base + belongs_to :location, class_name: "Country" + has_one :country_flag, through: :location, source: :flag + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "country.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + test "recognizes has_and_belongs_to_many model associations" do response = generate_definitions_for_source(<<~RUBY, { line: 3, character: 27 }) # typed: false @@ -111,11 +183,11 @@ class Profile < ActiveRecord::Base end test "handles class_name argument for associations" do - response = generate_definitions_for_source(<<~RUBY, { line: 3, character: 11 }) + response = generate_definitions_for_source(<<~RUBY, { line: 3, character: 14 }) # typed: false class User < ActiveRecord::Base - has_one :location, class_name: "Country" + belongs_to :location, class_name: "Country" end RUBY @@ -467,6 +539,48 @@ def name; end assert_equal(15, response.range.end.character) end + test "recognizes string main association on has_many :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 14 }) + class Organization < ApplicationRecord + has_many :memberships + has_many "users", through: :memberships + end + + class User < ApplicationRecord + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "user.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + + test "recognizes string through association on has_many :through association" do + response = generate_definitions_for_source(<<~RUBY, { line: 2, character: 32 }) + class Organization < ApplicationRecord + has_many "memberships" + has_many :users, through: "memberships" + end + + class Membership < ApplicationRecord + end + RUBY + + assert_equal(1, response.size) + + assert_equal( + URI::Generic.from_path(path: File.join(dummy_root, "app", "models", "membership.rb")).to_s, + response[0].uri, + ) + assert_equal(2, response[0].range.start.line) + assert_equal(2, response[0].range.end.line) + end + private def generate_definitions_for_source(source, position) diff --git a/test/ruby_lsp_rails/hover_test.rb b/test/ruby_lsp_rails/hover_test.rb index e9bca9e4..19577191 100644 --- a/test/ruby_lsp_rails/hover_test.rb +++ b/test/ruby_lsp_rails/hover_test.rb @@ -321,6 +321,162 @@ class Bar < ApplicationRecord CONTENT end + test "returns main association on has_many :through association" do + expected_response = { + location: "#{dummy_root}/app/models/user.rb:2", + name: "User", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 14 }) + class Organization < ApplicationRecord + has_many :memberships + has_many :users, through: :memberships + end + + class User < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + User + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + + test "returns through association on has_many :through association" do + expected_response = { + location: "#{dummy_root}/app/models/membership.rb:3", + name: "Membership", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 31 }) + class Organization < ApplicationRecord + has_many :memberships + has_many :users, through: :memberships + end + + class Membership < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + Membership + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + + test "returns main association on has_one :through association" do + expected_response = { + location: "#{dummy_root}/app/models/flag.rb:2", + name: "Flag", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 13 }) + class User < ApplicationRecord + belongs_to :location, class_name: "Country" + has_one :country_flag, through: :location, source: :flag + end + + class Flag < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + Flag + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + + test "returns through association on has_one :through association" do + expected_response = { + location: "#{dummy_root}/app/models/country.rb:2", + name: "Country", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 37 }) + class User < ApplicationRecord + belongs_to :location, class_name: "Country" + has_one :country_flag, through: :location, source: :flag + end + + class Country < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + Country + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + + test "returns string main association on has_many :through association" do + expected_response = { + location: "#{dummy_root}/app/models/user.rb:2", + name: "User", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 14 }) + class Organization < ApplicationRecord + has_many :memberships + has_many "users", through: :memberships + end + + class User < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + User + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + + test "returns string through association on has_many :through association" do + expected_response = { + location: "#{dummy_root}/app/models/membership.rb:3", + name: "Membership", + } + RunnerClient.any_instance.stubs(association_target: expected_response) + + response = hover_on_source(<<~RUBY, { line: 2, character: 32 }) + class Organization < ApplicationRecord + has_many "memberships" + has_many :users, through: "memberships" + end + + class Membership < ApplicationRecord + end + RUBY + + assert_equal(<<~CONTENT.chomp, response.contents.value) + ```ruby + Membership + ``` + + **Definitions**: [fake.rb](file:///fake.rb#L6,1-7,4) + CONTENT + end + private def hover_on_source(source, position)