diff --git a/Changelog.md b/Changelog.md index 6bd2cff4..02be5a28 100644 --- a/Changelog.md +++ b/Changelog.md @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## v3.22.0 + +### Fixed + +Make connection switching thread safe, by fixing a thread safety issue caused by using a (class) instance variable instead of a thread-local variable. + ## v3.21.0 ### Added diff --git a/active_record_shards.gemspec b/active_record_shards.gemspec index 0f342c34..3aef9ed0 100644 --- a/active_record_shards.gemspec +++ b/active_record_shards.gemspec @@ -1,4 +1,4 @@ -Gem::Specification.new "active_record_shards", "3.21.0" do |s| +Gem::Specification.new "active_record_shards", "3.22.0" do |s| s.authors = ["Benjamin Quorning", "Gabe Martin-Dempesy", "Pierre Schambacher", "Mick Staugaard", "Eric Chapweske", "Ben Osheroff"] s.email = ["bquorning@zendesk.com", "gabe@zendesk.com", "pschambacher@zendesk.com", "mick@staugaard.com"] s.homepage = "https://github.com/zendesk/active_record_shards" diff --git a/lib/active_record_shards/connection_switcher.rb b/lib/active_record_shards/connection_switcher.rb index ccb2cdca..86496664 100644 --- a/lib/active_record_shards/connection_switcher.rb +++ b/lib/active_record_shards/connection_switcher.rb @@ -113,12 +113,11 @@ def on_primary(&block) alias_method :with_slave_unless, :on_replica_unless def on_cx_switch_block(which, force: false, construct_ro_scope: nil, &block) - @disallow_replica ||= 0 - @disallow_replica += 1 if [:primary, :master].include?(which) + self.disallow_replica += 1 if [:primary, :master].include?(which) ActiveRecordShards::Deprecation.warn('the `:master` option should be replaced with `:primary`!') if which == :master - switch_to_replica = force || @disallow_replica.zero? + switch_to_replica = force || disallow_replica.zero? old_options = current_shard_selection.options switch_connection(replica: switch_to_replica) @@ -131,10 +130,18 @@ def on_cx_switch_block(which, force: false, construct_ro_scope: nil, &block) readonly.scoping(&block) end ensure - @disallow_replica -= 1 if [:primary, :master].include?(which) + self.disallow_replica -= 1 if [:primary, :master].include?(which) switch_connection(old_options) if old_options end + def disallow_replica=(value) + Thread.current[:__active_record_shards__disallow_replica_by_thread] = value + end + + def disallow_replica + Thread.current[:__active_record_shards__disallow_replica_by_thread] ||= 0 + end + def supports_sharding? shard_names.any? end diff --git a/test/database.yml b/test/database.yml index 07386391..cf8a4f4e 100644 --- a/test/database.yml +++ b/test/database.yml @@ -8,6 +8,7 @@ mysql: &MYSQL port: <%= mysql.port %> password: <%= mysql.password %> ssl_mode: :disabled + reaping_frequency: 0 # Prevents ActiveRecord from spawning reaping threads. # We connect to the unsharded primary database on a different port, via a proxy, # so we can make the connection unavailable when testing on_replica_by_default diff --git a/test/helper.rb b/test/helper.rb index ce44b602..37117907 100644 --- a/test/helper.rb +++ b/test/helper.rb @@ -1,5 +1,9 @@ # frozen_string_literal: true +# Stop Minitest creating threads we won't use. +# They add noise to the thread safety tests when inspecting `Thread.list`. +ENV["MT_CPU"] ||= "1" + require 'bundler/setup' require 'minitest/autorun' require 'minitest/rg' diff --git a/test/thread_safety_test.rb b/test/thread_safety_test.rb new file mode 100644 index 00000000..1c0467cb --- /dev/null +++ b/test/thread_safety_test.rb @@ -0,0 +1,273 @@ +# frozen_string_literal: true + +require_relative 'helper' +require_relative 'models' + +describe "connection switching thread safety" do + with_fresh_databases + + before do + ActiveRecord::Base.establish_connection(:test) + use_same_connection_handler_for_all_theads + create_seed_data + end + + after do + ActiveRecord::Base.connection_handler.clear_all_connections! + end + + it "can safely switch between all database connections in parallel" do + new_thread("switches_through_all_1") do + pause_and_mark_ready + switch_through_all_databases + end + new_thread("switches_through_all_2") do + pause_and_mark_ready + switch_through_all_databases + end + new_thread("switches_through_all_3") do + pause_and_mark_ready + switch_through_all_databases + end + + wait_for_threads_to_be_ready + execute_and_wait_for_threads + end + + describe "when multiple threads use different databases" do + it "allows threads to parallelize their IO" do + results = [] + + query_delay = { fast: "0.01", slow: "1", medium: "0.5" } + new_thread("different_db_parallel_thread1") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + result = execute_sql("SELECT name,'slower query',SLEEP(#{query_delay.fetch(:slow)}) FROM accounts") + assert_equal('Primary account', result.first[0]) + results.push(result) + end + end + + new_thread("different_db_parallel_thread2") do + ActiveRecord::Base.on_replica do + pause_and_mark_ready + result = execute_sql("SELECT name, 'faster query',SLEEP(#{query_delay.fetch(:fast)}) FROM accounts") + assert_equal('Replica account', result.first[0]) + results.push(result) + end + end + + new_thread("different_db_parallel_thread3") do + ActiveRecord::Base.on_shard(0) do + pause_and_mark_ready + result = execute_sql("SELECT title, 'medium query',SLEEP(#{query_delay.fetch(:medium)}) FROM tickets") + assert_equal('Shard 0 Primary ticket', result.first[0]) + results.push(result) + end + end + + wait_for_threads_to_be_ready + + thread_exection_time = Benchmark.realtime do + execute_and_wait_for_threads + end + + minimum_serial_query_exection_time = query_delay.values.map(&:to_f).sum + # Arbitrarily faster time such that there must have been some parallelization + max_parallel_time = minimum_serial_query_exection_time - 0.1 + assert_operator(max_parallel_time, :>, thread_exection_time) + + # This order cannot be guaranteed but it likely given the artificial delays + rows = results.map(&:first) + result_strings = rows.map { |r| r[1] } + assert_equal( + [ + "faster query", + "medium query", + "slower query" + ], + result_strings + ) + end + end + + describe "when multiple threads use the same database" do + it "exposes a different connections to each thread" do + connections = [] + + new_thread("connection_per_thread1") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + connections << ActiveRecord::Base.connection + end + end + + new_thread("connection_per_thread2") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + connections << ActiveRecord::Base.connection + end + end + + wait_for_threads_to_be_ready + execute_and_wait_for_threads + + expect(connections.first).must_be_kind_of(ActiveRecord::ConnectionAdapters::Mysql2Adapter) + assert_equal(2, connections.uniq.size) + end + + it "allows threads to parallelize their IO" do + results = [] + + query_delay = { fast: "0.01", slow: "1", medium: "0.5" } + new_thread("same_db_parallel_thread1") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + result = execute_sql("SELECT 'slower query',SLEEP(#{query_delay.fetch(:slow)})") + results.push(result) + end + end + + new_thread("same_db_parallel_thread2") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + result = execute_sql("SELECT 'faster query',SLEEP(#{query_delay.fetch(:fast)})") + results.push(result) + end + end + + new_thread("same_db_parallel_thread3") do + ActiveRecord::Base.on_primary do + pause_and_mark_ready + result = execute_sql("SELECT 'medium query',SLEEP(#{query_delay.fetch(:medium)})") + results.push(result) + end + end + + wait_for_threads_to_be_ready + + thread_exection_time = Benchmark.realtime do + execute_and_wait_for_threads + end + + minimum_serial_query_exection_time = query_delay.values.map(&:to_f).sum + # Arbitrarily faster time such that there must have been some parallelization + max_parallel_time = minimum_serial_query_exection_time - 0.1 + assert_operator(max_parallel_time, :>, thread_exection_time) + + rows = results.map(&:first) + result_strings = rows.map(&:first) + # This order cannot be guaranteed but it likely given the artificial delays + assert_equal( + [ + "faster query", + "medium query", + "slower query" + ], + result_strings + ) + end + end + + def new_thread(name) + thread = Thread.new do + Thread.current.name = name + yield + end + + @test_threads ||= [] + @test_threads.push(thread) + end + + def switch_through_all_databases + ActiveRecord::Base.on_primary do + result = ActiveRecord::Base.connection.execute("SELECT * from accounts") + assert_equal("Primary account", record_name(result)) + end + ActiveRecord::Base.on_replica do + result = ActiveRecord::Base.connection.execute("SELECT * from accounts") + assert_equal("Replica account", record_name(result)) + end + ActiveRecord::Base.on_shard(0) do + result = ActiveRecord::Base.connection.execute("SELECT * from tickets") + assert_equal("Shard 0 Primary ticket", record_name(result)) + + ActiveRecord::Base.on_replica do + result = ActiveRecord::Base.connection.execute("SELECT * from tickets") + assert_equal("Shard 0 Replica ticket", record_name(result)) + end + end + ActiveRecord::Base.on_shard(1) do + result = ActiveRecord::Base.connection.execute("SELECT * from tickets") + assert_equal("Shard 1 Primary ticket", record_name(result)) + + ActiveRecord::Base.on_replica do + result = ActiveRecord::Base.connection.execute("SELECT * from tickets") + assert_equal("Shard 1 Replica ticket", record_name(result)) + end + end + end + + # This allows us to get all of our threads into a prepared state by pausing + # them at a 'ready' point so as there is as little overhead as possible + # before the interesting code executes. + # + # Here we use 'ready' to mean the thread is spawned, has had its names set + # and has established a database connection. + def pause_and_mark_ready + Thread.current[:ready] = true + sleep + end + + def execute_and_wait_for_threads + @test_threads.each { |t| t.wakeup if t.alive? } + @test_threads.each(&:join) + end + + def wait_for_threads_to_be_ready + sleep(0.01) until @test_threads.all? { |t| t[:ready] } + end + + def use_same_connection_handler_for_all_theads + ActiveRecord::Base.default_connection_handler = ActiveRecord::Base.connection_handler + end + + def record_name(db_result) + name_column_index = 1 + db_result.first[name_column_index] + end + + def execute_sql(query) + ActiveRecord::Base.connection.execute(query) + end + + def create_seed_data + ActiveRecord::Base.on_primary_db do + Account.connection.execute(account_insert_sql(name: "Primary account")) + + Account.on_replica do + Account.connection.execute(account_insert_sql(name: "Replica account")) + end + end + + [0, 1].each do |shard_id| + ActiveRecord::Base.on_shard(shard_id) do + Ticket.connection.execute(ticket_insert_sql(title: "Shard #{shard_id} Primary ticket")) + + Ticket.on_replica do + Ticket.connection.execute(ticket_insert_sql(title: "Shard #{shard_id} Replica ticket")) + end + end + end + end + + def account_insert_sql(name:) + "INSERT INTO accounts (id, name, created_at, updated_at)" \ + " VALUES (1000, '#{name}', NOW(), NOW())" + end + + def ticket_insert_sql(title:) + "INSERT INTO tickets (id, title, account_id, created_at, updated_at)" \ + " VALUES (1000, '#{title}', 5000, NOW(), NOW())" + end +end