Skip to content

Commit a9469a2

Browse files
authored
Fix sqlite connection reinitialization (RustPython#6288)
* Fix sqlite connection reinitialization * Align sqlite connection reinit with CPython * Enable sqlite test_connection_bad_reinit * Fix sqlite reinit flag without threading * Use stronger memory ordering for initialized flag synchronization
1 parent 567fb4d commit a9469a2

File tree

2 files changed

+64
-26
lines changed

2 files changed

+64
-26
lines changed

Lib/test/test_sqlite3/test_dbapi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,6 @@ def test_connection_reinit(self):
573573
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
574574
self.assertEqual([r[0] for r in rows], ["2", "3"])
575575

576-
# TODO: RUSTPYTHON
577-
@unittest.expectedFailure
578576
def test_connection_bad_reinit(self):
579577
cx = sqlite.connect(":memory:")
580578
with cx:

crates/stdlib/src/sqlite.rs

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,11 @@ mod _sqlite {
833833
#[derive(PyPayload)]
834834
struct Connection {
835835
db: PyMutex<Option<Sqlite>>,
836-
detect_types: c_int,
836+
initialized: PyAtomic<bool>,
837+
detect_types: PyAtomic<c_int>,
837838
isolation_level: PyAtomicRef<Option<PyStr>>,
838-
check_same_thread: bool,
839-
thread_ident: ThreadId,
839+
check_same_thread: PyAtomic<bool>,
840+
thread_ident: PyMutex<ThreadId>, // TODO: Use atomic
840841
row_factory: PyAtomicRef<Option<PyObject>>,
841842
text_factory: PyAtomicRef<PyObject>,
842843
}
@@ -865,12 +866,15 @@ mod _sqlite {
865866
None
866867
};
867868

869+
let initialized = db.is_some();
870+
868871
let conn = Self {
869872
db: PyMutex::new(db),
870-
detect_types: args.detect_types,
873+
initialized: Radium::new(initialized),
874+
detect_types: Radium::new(args.detect_types),
871875
isolation_level: PyAtomicRef::from(args.isolation_level),
872-
check_same_thread: args.check_same_thread,
873-
thread_ident: std::thread::current().id(),
876+
check_same_thread: Radium::new(args.check_same_thread),
877+
thread_ident: PyMutex::new(std::thread::current().id()),
874878
row_factory: PyAtomicRef::from(None),
875879
text_factory: PyAtomicRef::from(text_factory),
876880
};
@@ -899,20 +903,51 @@ mod _sqlite {
899903
type Args = ConnectArgs;
900904

901905
fn init(zelf: PyRef<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
902-
let mut guard = zelf.db.lock();
903-
if guard.is_some() {
904-
// Already initialized
905-
return Ok(());
906+
let was_initialized = Radium::swap(&zelf.initialized, false, Ordering::AcqRel);
907+
908+
// Reset factories to their defaults, matching CPython's behavior.
909+
zelf.reset_factories(vm);
910+
911+
if was_initialized {
912+
zelf.drop_db();
906913
}
907914

915+
// Attempt to open the new database before mutating other state so failures leave
916+
// the connection uninitialized (and subsequent operations raise ProgrammingError).
908917
let db = Self::initialize_db(&args, vm)?;
918+
919+
let ConnectArgs {
920+
detect_types,
921+
isolation_level,
922+
check_same_thread,
923+
..
924+
} = args;
925+
926+
zelf.detect_types.store(detect_types, Ordering::Relaxed);
927+
zelf.check_same_thread
928+
.store(check_same_thread, Ordering::Relaxed);
929+
*zelf.thread_ident.lock() = std::thread::current().id();
930+
let _ = unsafe { zelf.isolation_level.swap(isolation_level) };
931+
932+
let mut guard = zelf.db.lock();
909933
*guard = Some(db);
934+
Radium::store(&zelf.initialized, true, Ordering::Release);
910935
Ok(())
911936
}
912937
}
913938

914939
#[pyclass(with(Constructor, Callable, Initializer), flags(BASETYPE))]
915940
impl Connection {
941+
fn drop_db(&self) {
942+
self.db.lock().take();
943+
}
944+
945+
fn reset_factories(&self, vm: &VirtualMachine) {
946+
let default_text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
947+
let _ = unsafe { self.row_factory.swap(None) };
948+
let _ = unsafe { self.text_factory.swap(default_text_factory) };
949+
}
950+
916951
fn initialize_db(args: &ConnectArgs, vm: &VirtualMachine) -> PyResult<Sqlite> {
917952
let path = args.database.to_cstring(vm)?;
918953
let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?);
@@ -1003,7 +1038,7 @@ mod _sqlite {
10031038
#[pymethod]
10041039
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
10051040
self.check_thread(vm)?;
1006-
self.db.lock().take();
1041+
self.drop_db();
10071042
Ok(())
10081043
}
10091044

@@ -1450,15 +1485,17 @@ mod _sqlite {
14501485
}
14511486

14521487
fn check_thread(&self, vm: &VirtualMachine) -> PyResult<()> {
1453-
if self.check_same_thread && (std::thread::current().id() != self.thread_ident) {
1454-
Err(new_programming_error(
1455-
vm,
1456-
"SQLite objects created in a thread can only be used in that same thread."
1457-
.to_owned(),
1458-
))
1459-
} else {
1460-
Ok(())
1488+
if self.check_same_thread.load(Ordering::Relaxed) {
1489+
let creator_id = *self.thread_ident.lock();
1490+
if std::thread::current().id() != creator_id {
1491+
return Err(new_programming_error(
1492+
vm,
1493+
"SQLite objects created in a thread can only be used in that same thread."
1494+
.to_owned(),
1495+
));
1496+
}
14611497
}
1498+
Ok(())
14621499
}
14631500

14641501
#[pygetset]
@@ -1632,7 +1669,8 @@ mod _sqlite {
16321669

16331670
inner.row_cast_map = zelf.build_row_cast_map(&st, vm)?;
16341671

1635-
inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
1672+
let detect_types = zelf.connection.detect_types.load(Ordering::Relaxed);
1673+
inner.description = st.columns_description(detect_types, vm)?;
16361674

16371675
if ret == SQLITE_ROW {
16381676
drop(st);
@@ -1680,7 +1718,8 @@ mod _sqlite {
16801718
));
16811719
}
16821720

1683-
inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
1721+
let detect_types = zelf.connection.detect_types.load(Ordering::Relaxed);
1722+
inner.description = st.columns_description(detect_types, vm)?;
16841723

16851724
inner.rowcount = if stmt.is_dml { 0 } else { -1 };
16861725

@@ -1845,15 +1884,16 @@ mod _sqlite {
18451884
st: &SqliteStatementRaw,
18461885
vm: &VirtualMachine,
18471886
) -> PyResult<Vec<Option<PyObjectRef>>> {
1848-
if self.connection.detect_types == 0 {
1887+
let detect_types = self.connection.detect_types.load(Ordering::Relaxed);
1888+
if detect_types == 0 {
18491889
return Ok(vec![]);
18501890
}
18511891

18521892
let mut cast_map = vec![];
18531893
let num_cols = st.column_count();
18541894

18551895
for i in 0..num_cols {
1856-
if self.connection.detect_types & PARSE_COLNAMES != 0 {
1896+
if detect_types & PARSE_COLNAMES != 0 {
18571897
let col_name = st.column_name(i);
18581898
let col_name = ptr_to_str(col_name, vm)?;
18591899
let col_name = col_name
@@ -1868,7 +1908,7 @@ mod _sqlite {
18681908
continue;
18691909
}
18701910
}
1871-
if self.connection.detect_types & PARSE_DECLTYPES != 0 {
1911+
if detect_types & PARSE_DECLTYPES != 0 {
18721912
let decltype = st.column_decltype(i);
18731913
let decltype = ptr_to_str(decltype, vm)?;
18741914
if let Some(decltype) = decltype.split_terminator(&[' ', '(']).next() {

0 commit comments

Comments
 (0)