Skip to content

Commit f61b62e

Browse files
authored
Ensure sqlite blob methods respect closed connections (RustPython#6290)
1 parent a9469a2 commit f61b62e

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

crates/stdlib/src/sqlite.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,18 +2207,24 @@ mod _sqlite {
22072207
self.inner.lock().take();
22082208
}
22092209

2210+
fn ensure_connection_open(&self, vm: &VirtualMachine) -> PyResult<()> {
2211+
if self.connection.is_closed() {
2212+
Err(new_programming_error(
2213+
vm,
2214+
"Cannot operate on a closed database".to_owned(),
2215+
))
2216+
} else {
2217+
Ok(())
2218+
}
2219+
}
2220+
22102221
#[pymethod]
22112222
fn read(
22122223
&self,
22132224
length: OptionalArg<c_int>,
22142225
vm: &VirtualMachine,
22152226
) -> PyResult<PyRef<PyBytes>> {
2216-
if self.connection.is_closed() {
2217-
return Err(new_programming_error(
2218-
vm,
2219-
"Cannot operate on a closed database".to_owned(),
2220-
));
2221-
}
2227+
self.ensure_connection_open(vm)?;
22222228

22232229
let mut length = length.unwrap_or(-1);
22242230
let mut inner = self.inner(vm)?;
@@ -2245,6 +2251,7 @@ mod _sqlite {
22452251

22462252
#[pymethod]
22472253
fn write(&self, data: PyBuffer, vm: &VirtualMachine) -> PyResult<()> {
2254+
self.ensure_connection_open(vm)?;
22482255
let mut inner = self.inner(vm)?;
22492256
let blob_len = inner.blob.bytes();
22502257
let length = Self::expect_write(blob_len, data.desc.len, inner.offset, vm)?;
@@ -2260,6 +2267,7 @@ mod _sqlite {
22602267

22612268
#[pymethod]
22622269
fn tell(&self, vm: &VirtualMachine) -> PyResult<c_int> {
2270+
self.ensure_connection_open(vm)?;
22632271
self.inner(vm).map(|x| x.offset)
22642272
}
22652273

@@ -2270,6 +2278,7 @@ mod _sqlite {
22702278
origin: OptionalArg<c_int>,
22712279
vm: &VirtualMachine,
22722280
) -> PyResult<()> {
2281+
self.ensure_connection_open(vm)?;
22732282
let origin = origin.unwrap_or(libc::SEEK_SET);
22742283
let mut inner = self.inner(vm)?;
22752284
let blob_len = inner.blob.bytes();
@@ -2299,12 +2308,14 @@ mod _sqlite {
22992308

23002309
#[pymethod]
23012310
fn __enter__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
2311+
zelf.ensure_connection_open(vm)?;
23022312
let _ = zelf.inner(vm)?;
23032313
Ok(zelf)
23042314
}
23052315

23062316
#[pymethod]
23072317
fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
2318+
self.ensure_connection_open(vm)?;
23082319
let _ = self.inner(vm)?;
23092320
self.close();
23102321
Ok(())
@@ -2351,6 +2362,7 @@ mod _sqlite {
23512362
}
23522363

23532364
fn subscript(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult {
2365+
self.ensure_connection_open(vm)?;
23542366
let inner = self.inner(vm)?;
23552367
if let Some(index) = needle.try_index_opt(vm) {
23562368
let blob_len = inner.blob.bytes();
@@ -2396,6 +2408,7 @@ mod _sqlite {
23962408
let Some(value) = value else {
23972409
return Err(vm.new_type_error("Blob doesn't support slice deletion"));
23982410
};
2411+
self.ensure_connection_open(vm)?;
23992412
let inner = self.inner(vm)?;
24002413

24012414
if let Some(index) = needle.try_index_opt(vm) {

0 commit comments

Comments
 (0)