@@ -833,10 +833,11 @@ mod _sqlite {
833833 #[ derive( PyPayload ) ]
834834 struct Connection {
835835 db : PyMutex < Option < Sqlite > > ,
836- detect_types : c_int ,
836+ initialized : PyAtomic < bool > ,
837+ detect_types : PyAtomic < c_int > ,
837838 isolation_level : PyAtomicRef < Option < PyStr > > ,
838- check_same_thread : bool ,
839- thread_ident : ThreadId ,
839+ check_same_thread : PyAtomic < bool > ,
840+ thread_ident : PyMutex < ThreadId > , // TODO: Use atomic
840841 row_factory : PyAtomicRef < Option < PyObject > > ,
841842 text_factory : PyAtomicRef < PyObject > ,
842843 }
@@ -865,12 +866,15 @@ mod _sqlite {
865866 None
866867 } ;
867868
869+ let initialized = db. is_some ( ) ;
870+
868871 let conn = Self {
869872 db : PyMutex :: new ( db) ,
870- detect_types : args. detect_types ,
873+ initialized : Radium :: new ( initialized) ,
874+ detect_types : Radium :: new ( args. detect_types ) ,
871875 isolation_level : PyAtomicRef :: from ( args. isolation_level ) ,
872- check_same_thread : args. check_same_thread ,
873- thread_ident : std:: thread:: current ( ) . id ( ) ,
876+ check_same_thread : Radium :: new ( args. check_same_thread ) ,
877+ thread_ident : PyMutex :: new ( std:: thread:: current ( ) . id ( ) ) ,
874878 row_factory : PyAtomicRef :: from ( None ) ,
875879 text_factory : PyAtomicRef :: from ( text_factory) ,
876880 } ;
@@ -899,20 +903,51 @@ mod _sqlite {
899903 type Args = ConnectArgs ;
900904
901905 fn init ( zelf : PyRef < Self > , args : Self :: Args , vm : & VirtualMachine ) -> PyResult < ( ) > {
902- let mut guard = zelf. db . lock ( ) ;
903- if guard. is_some ( ) {
904- // Already initialized
905- return Ok ( ( ) ) ;
906+ let was_initialized = Radium :: swap ( & zelf. initialized , false , Ordering :: AcqRel ) ;
907+
908+ // Reset factories to their defaults, matching CPython's behavior.
909+ zelf. reset_factories ( vm) ;
910+
911+ if was_initialized {
912+ zelf. drop_db ( ) ;
906913 }
907914
915+ // Attempt to open the new database before mutating other state so failures leave
916+ // the connection uninitialized (and subsequent operations raise ProgrammingError).
908917 let db = Self :: initialize_db ( & args, vm) ?;
918+
919+ let ConnectArgs {
920+ detect_types,
921+ isolation_level,
922+ check_same_thread,
923+ ..
924+ } = args;
925+
926+ zelf. detect_types . store ( detect_types, Ordering :: Relaxed ) ;
927+ zelf. check_same_thread
928+ . store ( check_same_thread, Ordering :: Relaxed ) ;
929+ * zelf. thread_ident . lock ( ) = std:: thread:: current ( ) . id ( ) ;
930+ let _ = unsafe { zelf. isolation_level . swap ( isolation_level) } ;
931+
932+ let mut guard = zelf. db . lock ( ) ;
909933 * guard = Some ( db) ;
934+ Radium :: store ( & zelf. initialized , true , Ordering :: Release ) ;
910935 Ok ( ( ) )
911936 }
912937 }
913938
914939 #[ pyclass( with( Constructor , Callable , Initializer ) , flags( BASETYPE ) ) ]
915940 impl Connection {
941+ fn drop_db ( & self ) {
942+ self . db . lock ( ) . take ( ) ;
943+ }
944+
945+ fn reset_factories ( & self , vm : & VirtualMachine ) {
946+ let default_text_factory = PyStr :: class ( & vm. ctx ) . to_owned ( ) . into_object ( ) ;
947+ let _ = unsafe { self . row_factory . swap ( None ) } ;
948+ let _ = unsafe { self . text_factory . swap ( default_text_factory) } ;
949+ }
950+
916951 fn initialize_db ( args : & ConnectArgs , vm : & VirtualMachine ) -> PyResult < Sqlite > {
917952 let path = args. database . to_cstring ( vm) ?;
918953 let db = Sqlite :: from ( SqliteRaw :: open ( path. as_ptr ( ) , args. uri , vm) ?) ;
@@ -1003,7 +1038,7 @@ mod _sqlite {
10031038 #[ pymethod]
10041039 fn close ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
10051040 self . check_thread ( vm) ?;
1006- self . db . lock ( ) . take ( ) ;
1041+ self . drop_db ( ) ;
10071042 Ok ( ( ) )
10081043 }
10091044
@@ -1450,15 +1485,17 @@ mod _sqlite {
14501485 }
14511486
14521487 fn check_thread ( & self , vm : & VirtualMachine ) -> PyResult < ( ) > {
1453- if self . check_same_thread && ( std:: thread:: current ( ) . id ( ) != self . thread_ident ) {
1454- Err ( new_programming_error (
1455- vm,
1456- "SQLite objects created in a thread can only be used in that same thread."
1457- . to_owned ( ) ,
1458- ) )
1459- } else {
1460- Ok ( ( ) )
1488+ if self . check_same_thread . load ( Ordering :: Relaxed ) {
1489+ let creator_id = * self . thread_ident . lock ( ) ;
1490+ if std:: thread:: current ( ) . id ( ) != creator_id {
1491+ return Err ( new_programming_error (
1492+ vm,
1493+ "SQLite objects created in a thread can only be used in that same thread."
1494+ . to_owned ( ) ,
1495+ ) ) ;
1496+ }
14611497 }
1498+ Ok ( ( ) )
14621499 }
14631500
14641501 #[ pygetset]
@@ -1632,7 +1669,8 @@ mod _sqlite {
16321669
16331670 inner. row_cast_map = zelf. build_row_cast_map ( & st, vm) ?;
16341671
1635- inner. description = st. columns_description ( zelf. connection . detect_types , vm) ?;
1672+ let detect_types = zelf. connection . detect_types . load ( Ordering :: Relaxed ) ;
1673+ inner. description = st. columns_description ( detect_types, vm) ?;
16361674
16371675 if ret == SQLITE_ROW {
16381676 drop ( st) ;
@@ -1680,7 +1718,8 @@ mod _sqlite {
16801718 ) ) ;
16811719 }
16821720
1683- inner. description = st. columns_description ( zelf. connection . detect_types , vm) ?;
1721+ let detect_types = zelf. connection . detect_types . load ( Ordering :: Relaxed ) ;
1722+ inner. description = st. columns_description ( detect_types, vm) ?;
16841723
16851724 inner. rowcount = if stmt. is_dml { 0 } else { -1 } ;
16861725
@@ -1845,15 +1884,16 @@ mod _sqlite {
18451884 st : & SqliteStatementRaw ,
18461885 vm : & VirtualMachine ,
18471886 ) -> PyResult < Vec < Option < PyObjectRef > > > {
1848- if self . connection . detect_types == 0 {
1887+ let detect_types = self . connection . detect_types . load ( Ordering :: Relaxed ) ;
1888+ if detect_types == 0 {
18491889 return Ok ( vec ! [ ] ) ;
18501890 }
18511891
18521892 let mut cast_map = vec ! [ ] ;
18531893 let num_cols = st. column_count ( ) ;
18541894
18551895 for i in 0 ..num_cols {
1856- if self . connection . detect_types & PARSE_COLNAMES != 0 {
1896+ if detect_types & PARSE_COLNAMES != 0 {
18571897 let col_name = st. column_name ( i) ;
18581898 let col_name = ptr_to_str ( col_name, vm) ?;
18591899 let col_name = col_name
@@ -1868,7 +1908,7 @@ mod _sqlite {
18681908 continue ;
18691909 }
18701910 }
1871- if self . connection . detect_types & PARSE_DECLTYPES != 0 {
1911+ if detect_types & PARSE_DECLTYPES != 0 {
18721912 let decltype = st. column_decltype ( i) ;
18731913 let decltype = ptr_to_str ( decltype, vm) ?;
18741914 if let Some ( decltype) = decltype. split_terminator ( & [ ' ' , '(' ] ) . next ( ) {
0 commit comments