diff --git a/integration_test/cases/execute_test.exs b/integration_test/cases/execute_test.exs index f00ac84..9f46a8f 100644 --- a/integration_test/cases/execute_test.exs +++ b/integration_test/cases/execute_test.exs @@ -348,6 +348,66 @@ defmodule ExecuteTest do ] = A.record(agent) end + test "execute disconnect_and_retry succeeds" do + err = RuntimeError.exception("oops") + + stack = [ + {:ok, :state}, + {:disconnect_and_retry, err, :state}, + :ok, + fn opts -> + send(opts[:parent], :reconnected) + {:ok, :new_state} + end, + {:ok, %Q{}, %R{}, :new_state} + ] + + {:ok, agent} = A.start_link(stack) + + opts = [agent: agent, parent: self()] + {:ok, pool} = P.start_link(opts) + assert P.execute(pool, %Q{}, [:param]) == {:ok, %Q{}, %R{}} + + assert_receive :reconnected + + assert [ + connect: [opts2], + handle_execute: [%Q{}, [:param], _, :state], + disconnect: [^err, :state], + connect: [opts2], + handle_execute: [%Q{}, [:param], _, :new_state] + ] = A.record(agent) + end + + test "execute disconnect_and_retry errors if there are no retries" do + err = RuntimeError.exception("oops") + + stack = [ + {:ok, :state}, + {:disconnect_and_retry, err, :new_state}, + :ok, + fn opts -> + send(opts[:parent], :reconnected) + {:ok, :state} + end + ] + + {:ok, agent} = A.start_link(stack) + + opts = [agent: agent, parent: self()] + {:ok, pool} = P.start_link(opts) + assert P.execute(pool, %Q{}, [:param], checkout_retries: 0) == {:error, err} + + assert_receive :reconnected + + assert [ + connect: [opts2], + handle_execute: [%Q{}, [:param], _, :state], + disconnect: [^err, :new_state], + connect: [opts2] + ] = A.record(agent) + end + test "execute bad return raises DBConnection.ConnectionError and stops" do stack = [ fn opts -> diff --git a/integration_test/cases/prepare_execute_test.exs b/integration_test/cases/prepare_execute_test.exs index f397bbc..0adebe1 100644 --- a/integration_test/cases/prepare_execute_test.exs +++ b/integration_test/cases/prepare_execute_test.exs @@ -496,6 +496,39 @@ defmodule PrepareExecuteTest do ] = A.record(agent) end + test "prepare_execute execute disconnect_and_retry succeeds" do + err = RuntimeError.exception("oops") + + stack = [ + {:ok, :state}, + {:disconnect_and_retry, err, :state}, + :ok, + fn opts -> + send(opts[:parent], :reconnected) + {:ok, :new_state} + end, + {:ok, %Q{}, :newer_state}, + {:ok, %Q{state: :executed}, %R{}, :newest_state} + ] + + {:ok, agent} = A.start_link(stack) + + opts = [agent: agent, parent: self()] + {:ok, pool} = P.start_link(opts) + assert P.prepare_execute(pool, %Q{}, [:param]) == {:ok, %Q{state: :executed}, %R{}} + + assert_receive :reconnected + + assert [ + connect: [opts2], + handle_prepare: [%Q{}, _, :state], + disconnect: [^err, :state], + connect: [opts2], + handle_prepare: [%Q{}, _, :new_state], + handle_execute: [%Q{}, [:param], _, :newer_state] + ] = A.record(agent) + end + test "prepare_execute describe or encode raises and closes query" do stack = [ {:ok, :state}, diff --git a/integration_test/cases/transaction_test.exs b/integration_test/cases/transaction_test.exs index 8b0bbe5..8963115 100644 --- a/integration_test/cases/transaction_test.exs +++ b/integration_test/cases/transaction_test.exs @@ -435,6 +435,70 @@ defmodule TransactionTest do ] = A.record(agent) end + test "transaction begin disconnect_and_retry succeeds" do + err = RuntimeError.exception("oops") + + stack = [ + {:ok, :state}, + {:disconnect_and_retry, err, :state}, + :ok, + fn opts -> + send(opts[:parent], :reconnected) + {:ok, :new_state} + end, + {:ok, :began, :newer_state}, + {:ok, :committed, :newest_state} + ] + + {:ok, agent} = A.start_link(stack) + + opts = [agent: agent, parent: self()] + {:ok, pool} = P.start_link(opts) + assert P.transaction(pool, fn _ -> :ok end) == {:ok, :ok} + assert_receive :reconnected + + assert [ + connect: [_], + handle_begin: [_, :state], + disconnect: [_, :state], + connect: [_], + handle_begin: [_, :new_state], + handle_commit: [_, :newer_state] + ] = A.record(agent) + end + + test "transaction begin disconnect_and_retry errors if there are no retries" do + err = RuntimeError.exception("oops") + + stack = [ + {:ok, :state}, + {:disconnect_and_retry, err, :new_state}, + :ok, + fn opts -> + send(opts[:parent], :reconnected) + {:ok, :newest_state} + end + ] + + {:ok, agent} = A.start_link(stack) + + opts = [agent: agent, parent: self()] + {:ok, pool} = P.start_link(opts) + + assert_raise RuntimeError, "oops", fn -> + P.transaction(pool, fn _ -> flunk("transaction ran") end, checkout_retries: 0) + end + + assert_receive :reconnected + + assert [ + connect: [_], + handle_begin: [_, :state], + disconnect: [_, :new_state], + connect: [_] + ] = A.record(agent) + end + test "transaction begin bad return raises and stops connection" do stack = [ fn opts -> diff --git a/lib/db_connection.ex b/lib/db_connection.ex index 92f56ab..27cbe53 100644 --- a/lib/db_connection.ex +++ b/lib/db_connection.ex @@ -203,10 +203,10 @@ defmodule DBConnection do Return `{:ok, result, state}`/`{:ok, query, result, state}` to continue, `{status, state}` to notify caller that the transaction can not begin due - to the transaction status `status`, or `{:disconnect, exception, state}` - to error and disconnect. If `{:ok, query, result, state}` is returned, - the query will be used to log the begin command. Otherwise, it will be - logged as `begin`. + to the transaction status `status`, or `{:disconnect | :disconnect_and_retry, exception, state}` + to error and disconnect (and optionally retry). If `{:ok, query, result, state}` + is returned, the query will be used to log the begin command. Otherwise, + it will be logged as `begin`. A callback implementation should only return `status` if it can determine the database's transaction status without side effect. @@ -217,7 +217,7 @@ defmodule DBConnection do {:ok, result, new_state :: any} | {:ok, query, result, new_state :: any} | {status, new_state :: any} - | {:disconnect, Exception.t(), new_state :: any} + | {:disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Handle committing a transaction. Return `{:ok, result, state}` on successfully @@ -255,20 +255,22 @@ defmodule DBConnection do Handle getting the transaction status. Return `{:idle, state}` if outside a transaction, `{:transaction, state}` if inside a transaction, `{:error, state}` if inside an aborted transaction, or - `{:disconnect, exception, state}` to error and disconnect. + `{:disconnect | :disconnect_and_retry, exception, state}` to error and disconnect + (and optionally retry). If the callback returns a `:disconnect` tuples then `status/2` will return `:error`. """ @callback handle_status(opts :: Keyword.t(), state :: any) :: {status, new_state :: any} - | {:disconnect, Exception.t(), new_state :: any} + | {:disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Prepare a query with the database. Return `{:ok, query, state}` where `query` is a query to pass to `execute/4` or `close/3`, `{:error, exception, state}` to return an error and continue or - `{:disconnect, exception, state}` to return an error and disconnect. + `{:disconnect | :disconnect_and_retry, exception, state}` to error and disconnect + (and optionally retry). This callback is intended for cases where the state of a connection is needed to prepare a query and/or the query can be saved in the @@ -278,45 +280,46 @@ defmodule DBConnection do """ @callback handle_prepare(query, opts :: Keyword.t(), state :: any) :: {:ok, query, new_state :: any} - | {:error | :disconnect, Exception.t(), new_state :: any} + | {:error | :disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Execute a query prepared by `c:handle_prepare/3`. Return `{:ok, query, result, state}` to return altered query `query` and result `result` and continue, `{:error, exception, state}` to return an error and - continue or `{:disconnect, exception, state}` to return an error and - disconnect. + continue or `{:disconnect | :disconnect_and_retry, exception, state}` to + error and disconnect (and optionally retry). This callback is called in the client process. """ @callback handle_execute(query, params, opts :: Keyword.t(), state :: any) :: {:ok, query, result, new_state :: any} - | {:error | :disconnect, Exception.t(), new_state :: any} + | {:error | :disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Close a query prepared by `c:handle_prepare/3` with the database. Return `{:ok, result, state}` on success and to continue, `{:error, exception, state}` to return an error and continue, or - `{:disconnect, exception, state}` to return an error and disconnect. + `{:disconnect | :disconnect_and_retry, exception, state}` to + error and disconnect (and optionally retry). This callback is called in the client process. """ @callback handle_close(query, opts :: Keyword.t(), state :: any) :: {:ok, result, new_state :: any} - | {:error | :disconnect, Exception.t(), new_state :: any} + | {:error | :disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Declare a cursor using a query prepared by `c:handle_prepare/3`. Return `{:ok, query, cursor, state}` to return altered query `query` and cursor `cursor` for a stream and continue, `{:error, exception, state}` to return an - error and continue or `{:disconnect, exception, state}` to return an error - and disconnect. + error and continue or `{:disconnect | :disconnect_and_retry, exception, state}` + to error and disconnect (and optionally retry). This callback is called in the client process. """ @callback handle_declare(query, params, opts :: Keyword.t(), state :: any) :: {:ok, query, cursor, new_state :: any} - | {:error | :disconnect, Exception.t(), new_state :: any} + | {:error | :disconnect | :disconnect_and_retry, Exception.t(), new_state :: any} @doc """ Fetch the next result from a cursor declared by `c:handle_declare/4`. Return @@ -358,11 +361,11 @@ defmodule DBConnection do The last known state will be sent and the exception will be a `DBConnection.ConnectionError` containing the reason for the exit. To have the same happen on unexpected shutdowns, you may trap exits from the `connect` callback. - """ @callback disconnect(err :: Exception.t(), state :: any) :: :ok @connection_module_key :connection_module + @checkout_retries 3 @doc """ Use `DBConnection` to set the behaviour. @@ -382,42 +385,61 @@ defmodule DBConnection do ### Options + * `:after_connect` - A function to run on connect using `run/3`, either + a 1-arity fun, `{module, function, args}` with `t:DBConnection.t/0` prepended + to `args` or `nil` (default: `nil`) + + * `:after_connect_timeout` - The maximum time allowed to perform + function specified by `:after_connect` option (default: `15_000`) + * `:backoff_min` - The minimum backoff interval (default: `1_000`) + * `:backoff_max` - The maximum backoff interval (default: `30_000`) + * `:backoff_type` - The backoff strategy, `:stop` for no backoff and - to stop, `:exp` for exponential, `:rand` for random and `:rand_exp` for - random exponential (default: `:rand_exp`) + to stop, `:exp` for exponential, `:rand` for random and `:rand_exp` for + random exponential (default: `:rand_exp`) + + * `:checkout_retries` - The number of times to checkout a new connection + whenever the operation fails because the database disconnected. Note + not all operations can be retried and each adapter specifies which + operations are safe to retry + * `:configure` - A function to run before every connect attempt to - dynamically configure the options, either a 1-arity fun, - `{module, function, args}` or `nil`. This function is called - *in the connection process*. For more details, see - [Connection Configuration Callback](#start_link/2-connection-configuration-callback) - * `:after_connect` - A function to run on connect using `run/3`, either - a 1-arity fun, `{module, function, args}` with `t:DBConnection.t/0` prepended - to `args` or `nil` (default: `nil`) - * `:after_connect_timeout` - The maximum time allowed to perform - function specified by `:after_connect` option (default: `15_000`) + dynamically configure the options, either a 1-arity fun, + `{module, function, args}` or `nil`. This function is called *in the + connection process*. For more details, see + [Connection Configuration Callback](#start_link/2-connection-configuration-callback) + * `:connection_listeners` - A list of process destinations to send notification messages whenever a connection is connected or disconnected. See "Connection listeners" below - * `:name` - A name to register the started process (see the `:name` option - in `GenServer.start_link/3`) - * `:pool` - Chooses the pool to be started (default: `DBConnection.ConnectionPool`). See - ["Connection pools"](#module-connection-pools). - * `:pool_size` - Chooses the size of the pool. Must be greater or equal to 1. (default: `1`) + * `:idle_interval` - Controls the frequency we check for idle connections in the pool. We then notify each idle connection to ping the database. In practice, the ping happens within `idle_interval <= ping < 2 * idle_interval`. Defaults to 1000ms. + * `:idle_limit` - The number of connections to ping on each `:idle_interval`. Defaults to the pool size (all connections). - * `:queue_target` and `:queue_interval` - See "Queue config" below + * `:max_restarts` and `:max_seconds` - Configures the `:max_restarts` and `:max_seconds` for the connection pool supervisor (see the `Supervisor` docs). Typically speaking the connection process doesn't terminate, except due to faults in DBConnection. However, if backoff has been disabled, then they also terminate whenever a connection is disconnected (for instance, due to client or server errors) + + * `:name` - A name to register the started process (see the `:name` option + in `GenServer.start_link/3`) + + * `:pool` - Chooses the pool to be started (default: `DBConnection.ConnectionPool`). + See ["Connection pools"](#module-connection-pools). + + * `:pool_size` - Chooses the size of the pool. Must be greater or equal to 1. (default: `1`) + + * `:queue_target` and `:queue_interval` - See "Queue config" below + * `:show_sensitive_data_on_connection_error` - By default, `DBConnection` hides all information during connection errors to avoid leaking credentials or other sensitive information. You can set this option if you wish to @@ -952,8 +974,15 @@ defmodule DBConnection do {:ok, conn, old_status, _} -> try do result = fun.(conn) - {:ok, new_status, _meter} = run_status(conn, nil, opts) - {result, new_status} + + case run_status(conn, nil, opts) do + {:ok, new_status, _meter} -> + {result, new_status} + + {:retry, err, _meter} -> + disconnect(conn, err) + {result, :error} + end catch kind, error -> checkin(conn) @@ -1323,12 +1352,6 @@ defmodule DBConnection do Holder.checkin(pool_ref) end - defp checkin(%DBConnection{} = conn, fun, meter, opts) do - return = fun.(conn, meter, opts) - checkin(conn) - return - end - defp disconnect(%DBConnection{pool_ref: pool_ref}, err) do _ = Holder.disconnect(pool_ref, err) :ok @@ -1341,6 +1364,17 @@ defmodule DBConnection do :ok end + defp retry_or_handle_common_result(return, conn, meter) do + case return do + {:disconnect_and_retry, err, _conn_state} -> + disconnect(conn, err) + {:retry, err, meter} + + _ -> + handle_common_result(return, conn, meter) + end + end + defp handle_common_result(return, conn, meter) do case return do {:ok, result, _conn_state} -> @@ -1488,7 +1522,7 @@ defmodule DBConnection do defp prepare(%DBConnection{pool_ref: pool_ref} = conn, query, meter, opts) do pool_ref |> Holder.handle(:handle_prepare, [query], opts) - |> handle_common_result(conn, event(meter, :prepare)) + |> retry_or_handle_common_result(conn, event(meter, :prepare)) end defp run_prepare_execute(conn, query, params, meter, opts) do @@ -1510,7 +1544,7 @@ defmodule DBConnection do bad_return!(other, conn, meter) other -> - handle_common_result(other, conn, meter) + retry_or_handle_common_result(other, conn, meter) end end @@ -1522,45 +1556,85 @@ defmodule DBConnection do defp run_close(conn, query, meter, opts) do meter = event(meter, :close) - run_cleanup(conn, :handle_close, [query], meter, opts) + + cleanup(conn, :handle_close, [query], opts) + |> retry_or_handle_common_result(conn, meter) end - defp run_cleanup(conn, fun, args, meter, opts) do + defp cleanup(conn, fun, args, opts) do %DBConnection{pool_ref: pool_ref} = conn - Holder.cleanup(pool_ref, fun, args, opts) - |> handle_common_result(conn, meter) end # run/4 and checkout/4 are the two entry points to get a connection. # run returns only the result, checkout also returns the connection. defp run(%DBConnection{} = conn, fun, meter, opts) do - fun.(conn, meter, opts) + with {:retry, err, meter} <- fun.(conn, meter, opts) do + {:error, err, meter} + end end defp run(pool, fun, meter, opts) do + retries = Keyword.get(opts, :checkout_retries, @checkout_retries) + run_with_retries(retries, pool, fun, meter, opts) + end + + defp run_with_retries(retries, pool, fun, meter, opts) do with {:ok, conn, meter} <- checkout(pool, meter, opts) do - try do - fun.(conn, meter, opts) - after - checkin(conn) + result = + try do + fun.(conn, meter, opts) + after + checkin(conn) + end + + case result do + {:retry, _err, meter} when retries > 0 -> + run_with_retries(retries - 1, pool, fun, meter, opts) + + {:retry, err, meter} -> + {:error, err, meter} + + other -> + other end end end defp checkout(%DBConnection{} = conn, fun, meter, opts) do - with {:ok, result, meter} <- fun.(conn, meter, opts) do - {:ok, conn, result, meter} + case fun.(conn, meter, opts) do + {:ok, result, meter} -> + {:ok, conn, result, meter} + + {:retry, err, meter} -> + {:error, err, meter} + + other -> + other end end defp checkout(pool, fun, meter, opts) do + retries = Keyword.get(opts, :checkout_retries, @checkout_retries) + checkout_with_retries(retries, pool, fun, meter, opts) + end + + defp checkout_with_retries(retries, pool, fun, meter, opts) do with {:ok, conn, meter} <- checkout(pool, meter, opts) do case fun.(conn, meter, opts) do {:ok, result, meter} -> {:ok, conn, result, meter} + {:retry, err, meter} -> + checkin(conn) + + if retries > 0 do + checkout_with_retries(retries - 1, pool, fun, meter, opts) + else + {:error, err, meter} + end + error -> checkin(conn) error @@ -1568,6 +1642,12 @@ defmodule DBConnection do end end + defp checkin(%DBConnection{} = conn, fun, meter, opts) do + return = fun.(conn, meter, opts) + checkin(conn) + return + end + defp meter(opts) do case Keyword.get(opts, :log) do nil -> nil @@ -1754,7 +1834,7 @@ defmodule DBConnection do {:ok, {query, result}, meter} other -> - handle_common_result(other, conn, meter) + retry_or_handle_common_result(other, conn, meter) end end @@ -1784,14 +1864,11 @@ defmodule DBConnection do err = DBConnection.TransactionError.exception(:error) {:error, err} - {query, other} -> - log(other, :commit, query, nil) + {:rollback, other} -> + log(other, :commit, :rollback, nil) - {:error, err, meter} -> - log(meter, :commit, :commit, nil, {:error, err}) - - {kind, reason, stack, meter} -> - log(meter, :commit, :commit, nil, {kind, reason, stack}) + other -> + log(other, :commit, :commit, nil) end end @@ -1804,10 +1881,10 @@ defmodule DBConnection do {:rollback, run_rollback(conn, meter, opts)} {status, _conn_state} when status in [:idle, :transaction] -> - {:commit, status_disconnect(conn, status, meter)} + status_disconnect(conn, status, meter) other -> - {:commit, handle_common_result(other, conn, meter)} + handle_common_result(other, conn, meter) end end @@ -1820,20 +1897,18 @@ defmodule DBConnection do defp run_status(conn, meter, opts) do %DBConnection{pool_ref: pool_ref} = conn + # status queries are not logged, which means we need to deal + # with catch and disconnections explicitly case Holder.handle(pool_ref, :handle_status, [], opts) do {status, _conn_state} when status in [:idle, :transaction, :error] -> {:ok, status, meter} - {:disconnect, err, _conn_state} -> - disconnect(conn, err) - {:ok, :error, meter} - - {:catch, kind, reason, stack} -> - stop(conn, kind, reason, stack) - :erlang.raise(kind, reason, stack) - other -> - bad_return!(other, conn, meter) + case retry_or_handle_common_result(other, conn, meter) do + {:error, _, meter} -> {:ok, :error, meter} + {kind, reason, stack, _meter} -> :erlang.raise(kind, reason, stack) + _ -> other + end end end @@ -1858,13 +1933,13 @@ defmodule DBConnection do bad_return!(other, conn, meter) other -> - handle_common_result(other, conn, meter) + retry_or_handle_common_result(other, conn, meter) end end defp stream_fetch(%DBConnection{} = conn, {:cont, query, cursor}, opts) do with {ok, result, meter} when ok in [:cont, :halt] <- - run_fetch(conn, [query, cursor], meter(opts), opts), + fetch(conn, [query, cursor], meter(opts), opts), {:ok, result, meter} <- decode(query, result, meter, opts) do {ok, result, meter} end @@ -1882,7 +1957,7 @@ defmodule DBConnection do {:halt, state} end - defp run_fetch(conn, args, meter, opts) do + defp fetch(conn, args, meter, opts) do %DBConnection{pool_ref: pool_ref} = conn meter = event(meter, :fetch) @@ -1902,7 +1977,8 @@ defmodule DBConnection do meter = event(meter(opts), :deallocate) conn - |> run_cleanup(:handle_deallocate, [query, cursor], meter, opts) + |> cleanup(:handle_deallocate, [query, cursor], opts) + |> handle_common_result(conn, meter) |> log(:deallocate, query, cursor) end diff --git a/lib/db_connection/ownership/proxy.ex b/lib/db_connection/ownership/proxy.ex index 1426c17..fa27fb0 100644 --- a/lib/db_connection/ownership/proxy.ex +++ b/lib/db_connection/ownership/proxy.ex @@ -61,7 +61,7 @@ defmodule DBConnection.Ownership.Proxy do @impl true def handle_info({:DOWN, ref, _, pid, _reason}, %{owner: {_, ref}} = state) do - down("owner #{Util.inspect_pid(pid)} exited", state) + shutdown("owner #{Util.inspect_pid(pid)} exited", state) end def handle_info({:timeout, deadline, {_ref, holder, pid, len}}, %{holder: holder} = state) do @@ -70,7 +70,7 @@ defmodule DBConnection.Ownership.Proxy do "client #{Util.inspect_pid(pid)} timed out because " <> "it queued and checked out the connection for longer than #{len}ms" - down(message, state) + shutdown(message, state) else {:noreply, state} end @@ -84,9 +84,9 @@ defmodule DBConnection.Ownership.Proxy do "owner #{Util.inspect_pid(pid)} timed out because " <> "it owned the connection for longer than #{timeout}ms (set via the :ownership_timeout option)" - # We don't invoke down because this is always a disconnect, even if there is no client. + # We don't invoke shutdown because this is always a disconnect, even if there is no client. # On the other hand, those timeouts are unlikely to trigger, as it defaults to 2 mins. - pool_disconnect(DBConnection.ConnectionError.exception(message), state) + pool_disconnect(DBConnection.ConnectionError.exception(message), false, state) end def handle_info({:timeout, poll, time}, %{poll: poll} = state) do @@ -147,13 +147,13 @@ defmodule DBConnection.Ownership.Proxy do ) do case msg do :checkin -> checkin(state) - :disconnect -> pool_disconnect(extra, state) + :disconnect -> pool_disconnect(extra, true, state) :stop -> pool_stop(extra, state) end end def handle_info({:"ETS-TRANSFER", holder, pid, ref}, %{holder: holder, owner: {_, ref}} = state) do - down("client #{Util.inspect_pid(pid)} exited", state) + shutdown("client #{Util.inspect_pid(pid)} exited", state) end @impl true @@ -171,7 +171,7 @@ defmodule DBConnection.Ownership.Proxy do Exception.format_stacktrace(current_stack) end - down(message, state) + shutdown(message, state) end @impl true @@ -222,18 +222,13 @@ defmodule DBConnection.Ownership.Proxy do :erlang.start_timer(timeout, self(), {__MODULE__, pid, timeout}) end - # It is down but never checked out from pool - defp down(reason, %{holder: nil} = state) do - {:stop, {:shutdown, reason}, state} - end - - # If it is down but it has no client, checkin - defp down(reason, %{client: nil} = state) do + # If shutting down but it has no client, checkin + defp shutdown(reason, %{client: nil} = state) do pool_checkin(reason, state) end - # If it is down but it has a client, disconnect - defp down(reason, %{client: {client, _, checkout_stack}} = state) do + # If shutting down but it has a client, disconnect + defp shutdown(reason, %{client: {client, _, checkout_stack}} = state) do reason = case pruned_stacktrace(client) do [] -> @@ -252,24 +247,27 @@ defmodule DBConnection.Ownership.Proxy do end err = DBConnection.ConnectionError.exception(reason) - pool_disconnect(err, state) + pool_disconnect(err, false, state) end ## Helpers defp pool_checkin(reason, state) do - pool_done(reason, state, :checkin, fn pool_ref, _ -> Holder.checkin(pool_ref) end) + checkin = fn pool_ref, _ -> Holder.checkin(pool_ref) end + pool_done(reason, state, :checkin, false, checkin, &Holder.disconnect/2) end - defp pool_disconnect(err, state) do - pool_done(err, state, {:disconnect, err}, &Holder.disconnect/2) + defp pool_disconnect(err, keep_alive?, state) do + disconnect = &Holder.disconnect/2 + pool_done(err, state, {:disconnect, err}, keep_alive?, disconnect, disconnect) end defp pool_stop(err, state) do - pool_done(err, state, {:stop, err}, &Holder.stop/2, &Holder.stop/2) + stop = &Holder.stop/2 + pool_done(err, state, {:stop, err}, false, stop, stop) end - defp pool_done(err, state, op, done, stop_or_disconnect \\ &Holder.disconnect/2) do + defp pool_done(err, state, op, keep_alive?, done, stop_or_disconnect) do %{holder: holder, pool_ref: pool_ref, pre_checkin: pre_checkin, mod: original_mod} = state if holder do @@ -279,7 +277,12 @@ defmodule DBConnection.Ownership.Proxy do {:ok, ^original_mod, conn_state} -> Holder.put_state(pool_ref, conn_state) done.(pool_ref, err) - {:stop, {:shutdown, err}, state} + + if keep_alive? do + {:noreply, %{state | holder: nil}} + else + {:stop, {:shutdown, err}, state} + end {:disconnect, err, ^original_mod, conn_state} -> Holder.put_state(pool_ref, conn_state)