Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions rust/core/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ pub trait Optionable {
fn get_option_double(&self, key: Self::Option) -> Result<f64>;
}

/// A handle to cancel an in-progress operation on a connection.
///
/// This is a separated handle because otherwise it would be impossible to
/// call a `cancel` method on a connection or statement itself.
pub trait CancelHandle: Send {
/// Cancel the in-progress operation on a connection.
fn try_cancel(&self) -> Result<()>;
}

/// A cancellation handle that does nothing (because cancellation is unsupported).
pub struct NoOpCancellationHandle;

impl CancelHandle for NoOpCancellationHandle {
fn try_cancel(&self) -> Result<()> {
Ok(())
}
}

/// A handle to an ADBC driver.
pub trait Driver {
type DatabaseType: Database;
Expand Down Expand Up @@ -76,6 +94,11 @@ pub trait Database: Optionable<Option = OptionDatabase> {
&self,
opts: impl IntoIterator<Item = (options::OptionConnection, OptionValue)>,
) -> Result<Self::ConnectionType>;

/// Get a handle to cancel operations on this database.
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
Box::new(NoOpCancellationHandle {})
}
}

/// A handle to an ADBC connection.
Expand All @@ -94,8 +117,10 @@ pub trait Connection: Optionable<Option = OptionConnection> {
/// Allocate and initialize a new statement.
fn new_statement(&mut self) -> Result<Self::StatementType>;

/// Cancel the in-progress operation on a connection.
fn cancel(&mut self) -> Result<()>;
/// Get a handle to cancel operations on this connection.
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
Box::new(NoOpCancellationHandle {})
}

/// Get metadata about the database/driver.
///
Expand Down Expand Up @@ -455,13 +480,15 @@ pub trait Statement: Optionable<Option = OptionStatement> {
/// expected to be executed repeatedly, call [Statement::prepare] first.
fn set_substrait_plan(&mut self, plan: impl AsRef<[u8]>) -> Result<()>;

/// Cancel execution of an in-progress query.
/// Get a handle to cancel operations on this statement.
///
/// This can be called during [Statement::execute] (or similar), or while
/// consuming a result set returned from such.
/// The resulting handle can be called during [Statement::execute] (or
/// similar), or while consuming a result set returned from such.
///
/// # Since
///
/// ADBC API revision 1.1.0
fn cancel(&mut self) -> Result<()>;
fn get_cancel_handle(&self) -> Box<dyn CancelHandle> {
Box::new(NoOpCancellationHandle {})
}
}
8 changes: 0 additions & 8 deletions rust/driver/datafusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,10 +735,6 @@ impl Connection for DataFusionConnection {
})
}

fn cancel(&mut self) -> adbc_core::error::Result<()> {
todo!()
}

fn get_info(
&self,
codes: Option<std::collections::HashSet<adbc_core::options::InfoCode>>,
Expand Down Expand Up @@ -984,10 +980,6 @@ impl Statement for DataFusionStatement {
self.substrait_plan = Some(Plan::decode(plan.as_ref()).unwrap());
Ok(())
}

fn cancel(&mut self) -> adbc_core::error::Result<()> {
todo!()
}
}

#[cfg(feature = "ffi")]
Expand Down
32 changes: 18 additions & 14 deletions rust/driver/dummy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,24 @@ impl Connection for DummyConnection {
Ok(Self::StatementType::default())
}

// This method is used to test that errors round-trip correctly.
fn cancel(&mut self) -> Result<()> {
let mut error = Error::with_message_and_status("message", Status::Cancelled);
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
error.sqlstate = [1, 2, 3, 4, 5];
error.details = Some(vec![
("key1".into(), b"AAA".into()),
("key2".into(), b"ZZZZZ".into()),
]);
Err(error)
/// This method is used to test that errors round-trip correctly.
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
struct CancelHandle;

impl adbc_core::CancelHandle for CancelHandle {
fn try_cancel(&self) -> Result<()> {
let mut error = Error::with_message_and_status("message", Status::Cancelled);
error.vendor_code = constants::ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA;
error.sqlstate = [1, 2, 3, 4, 5];
error.details = Some(vec![
("key1".into(), b"AAA".into()),
("key2".into(), b"ZZZZZ".into()),
]);
Err(error)
}
}

Box::new(CancelHandle)
}

fn commit(&mut self) -> Result<()> {
Expand Down Expand Up @@ -854,10 +862,6 @@ impl Statement for DummyStatement {
Ok(())
}

fn cancel(&mut self) -> Result<()> {
Ok(())
}

fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
maybe_panic("StatementExecuteQuery");
let batch = get_table_data();
Expand Down
22 changes: 14 additions & 8 deletions rust/driver/dummy/tests/driver_exporter_dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,14 @@ fn test_connection_get_info() {

#[test]
fn test_connection_cancel() {
let (_, _, mut exported_connection, _) = get_exported();
let (_, _, mut native_connection, _) = get_native();
let (_, _, exported_connection, _) = get_exported();
let (_, _, native_connection, _) = get_native();

let exported_handle = exported_connection.get_cancel_handle();
let native_handle = native_connection.get_cancel_handle();

let exported_error = exported_connection.cancel().unwrap_err();
let native_error = native_connection.cancel().unwrap_err();
let exported_error = exported_handle.try_cancel().unwrap_err();
let native_error = native_handle.try_cancel().unwrap_err();

assert_eq!(exported_error, native_error);
}
Expand Down Expand Up @@ -569,11 +572,14 @@ fn test_statement_bind_stream() {

#[test]
fn test_statement_cancel() {
let (_, _, _, mut exported_statement) = get_exported();
let (_, _, _, mut native_statement) = get_native();
let (_, _, _, exported_statement) = get_exported();
let (_, _, _, native_statement) = get_native();

let exported_handle = exported_statement.get_cancel_handle();
let native_handle = native_statement.get_cancel_handle();

exported_statement.cancel().unwrap();
native_statement.cancel().unwrap();
exported_handle.try_cancel().unwrap();
native_handle.try_cancel().unwrap();
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions rust/driver/snowflake/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl adbc_core::Connection for Connection {
self.0.new_statement().map(Statement)
}

fn cancel(&mut self) -> Result<()> {
self.0.cancel()
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
self.0.get_cancel_handle()
}

fn get_info(
Expand Down
4 changes: 2 additions & 2 deletions rust/driver/snowflake/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl adbc_core::Statement for Statement {
self.0.set_substrait_plan(plan)
}

fn cancel(&mut self) -> Result<()> {
self.0.cancel()
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
self.0.get_cancel_handle()
}
}
84 changes: 58 additions & 26 deletions rust/driver_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,31 @@ pub struct ManagedConnection {
inner: Arc<ManagedConnectionInner>,
}

struct ConnectionCancelHandle {
inner: std::sync::Weak<ManagedConnectionInner>,
}

impl adbc_core::CancelHandle for ConnectionCancelHandle {
fn try_cancel(&self) -> Result<()> {
if let Some(inner) = self.inner.upgrade() {
if let AdbcVersion::V100 = inner.database.driver.version {
return Err(Error::with_message_and_status(
ERR_CANCEL_UNSUPPORTED,
Status::NotImplemented,
));
}
let driver = &inner.database.driver.driver;
let mut connection = inner.connection.lock().unwrap();
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionCancel);
let status = unsafe { method(connection.deref_mut(), &mut error) };
check_status(status, error)
} else {
Ok(())
}
}
}

impl ManagedConnection {
fn ffi_driver(&self) -> &adbc_ffi::FFI_AdbcDriver {
&self.inner.database.driver.driver
Expand Down Expand Up @@ -1125,19 +1150,10 @@ impl Connection for ManagedConnection {
Ok(Self::StatementType { inner })
}

fn cancel(&mut self) -> Result<()> {
if let AdbcVersion::V100 = self.driver_version() {
return Err(Error::with_message_and_status(
ERR_CANCEL_UNSUPPORTED,
Status::NotImplemented,
));
}
let driver = self.ffi_driver();
let mut connection = self.inner.connection.lock().unwrap();
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionCancel);
let status = unsafe { method(connection.deref_mut(), &mut error) };
check_status(status, error)
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
Box::new(ConnectionCancelHandle {
inner: Arc::downgrade(&self.inner),
})
}

fn commit(&mut self) -> Result<()> {
Expand Down Expand Up @@ -1401,6 +1417,31 @@ impl ManagedStatement {
}
}

struct StatementCancelHandle {
inner: std::sync::Weak<ManagedStatementInner>,
}

impl adbc_core::CancelHandle for StatementCancelHandle {
fn try_cancel(&self) -> Result<()> {
if let Some(inner) = self.inner.upgrade() {
if let AdbcVersion::V100 = inner.connection.database.driver.version {
return Err(Error::with_message_and_status(
ERR_CANCEL_UNSUPPORTED,
Status::NotImplemented,
));
}
let driver = &inner.connection.database.driver.driver;
let mut statement = inner.statement.lock().unwrap();
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, StatementCancel);
let status = unsafe { method(statement.deref_mut(), &mut error) };
check_status(status, error)
} else {
Ok(())
}
}
}

impl Statement for ManagedStatement {
fn bind(&mut self, batch: RecordBatch) -> Result<()> {
let driver = self.ffi_driver();
Expand All @@ -1425,19 +1466,10 @@ impl Statement for ManagedStatement {
Ok(())
}

fn cancel(&mut self) -> Result<()> {
if let AdbcVersion::V100 = self.driver_version() {
return Err(Error::with_message_and_status(
ERR_CANCEL_UNSUPPORTED,
Status::NotImplemented,
));
}
let driver = self.ffi_driver();
let mut statement = self.inner.statement.lock().unwrap();
let mut error = adbc_ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, StatementCancel);
let status = unsafe { method(statement.deref_mut(), &mut error) };
check_status(status, error)
fn get_cancel_handle(&self) -> Box<dyn adbc_core::CancelHandle> {
Box::new(StatementCancelHandle {
inner: Arc::downgrade(&self.inner),
})
}

fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
Expand Down
10 changes: 6 additions & 4 deletions rust/driver_manager/tests/driver_manager_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,10 @@ fn test_connection_get_option() {
fn test_connection_cancel() {
let mut driver = get_driver();
let database = get_database(&mut driver);
let mut connection = database.new_connection().unwrap();
let connection = database.new_connection().unwrap();

let error = connection.cancel().unwrap_err();
let handle = connection.get_cancel_handle();
let error = handle.try_cancel().unwrap_err();
assert_eq!(error.status, Status::NotImplemented);
}

Expand Down Expand Up @@ -285,9 +286,10 @@ fn test_statement_cancel() {
let mut driver = get_driver();
let database = get_database(&mut driver);
let mut connection = database.new_connection().unwrap();
let mut statement = connection.new_statement().unwrap();
let statement = connection.new_statement().unwrap();

let error = statement.cancel().unwrap_err();
let handle = statement.get_cancel_handle();
let error = handle.try_cancel().unwrap_err();
assert_eq!(error.status, Status::NotImplemented);
}

Expand Down
Loading
Loading