Skip to content
Open
23 changes: 23 additions & 0 deletions libs/opsqueue_python/python/opsqueue/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
]


class LookupIdsWithEmptyStrategicMetadataError(Exception):
    """Raised when a lookup of submission IDs by strategic metadata is
    attempted with a 'strategic_metadata' that contains no key-value pairs.
    """


class ProducerClient:
"""
Opsqueue producer client. Allows sending of large collections of operations ('submissions')
Expand Down Expand Up @@ -367,6 +371,25 @@ def lookup_submission_id_by_prefix(self, prefix: str) -> SubmissionId | None:
"""
return self.inner.lookup_submission_id_by_prefix(prefix)

def lookup_submission_ids_by_strategic_metadata(
    self, strategic_metadata: dict[str, int]
) -> list[SubmissionId]:
    """Looks up the IDs of in-progress submissions whose strategic metadata
    contains every key-value pair of the given 'strategic_metadata'.

    The given pairs act as a filter: a submission matches when all of them
    are present in its strategic metadata. Any additional key-value pairs
    the submission carries are ignored.

    Raises:
    - `LookupIdsWithEmptyStrategicMetadataError` if 'strategic_metadata'
      holds no key-value pairs at all.

    """
    if len(strategic_metadata) > 0:
        # At least one pair to filter on: delegate to the inner client.
        return self.inner.lookup_submission_ids_by_strategic_metadata(  # type: ignore[no-any-return]
            strategic_metadata
        )
    raise LookupIdsWithEmptyStrategicMetadataError()

def is_completed(self, submission_id: SubmissionId) -> bool:
    """Not implemented: calling this always raises `NotImplementedError`."""
    raise NotImplementedError()

Expand Down
18 changes: 18 additions & 0 deletions libs/opsqueue_python/src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,24 @@ impl ProducerClient {
})
}

/// Looks up the IDs of the submissions that match ALL key-value pairs of
/// the given strategic metadata.
///
/// The lookup is delegated to the wrapped producer client and driven to
/// completion through `block_unless_interrupted`, inside `py.allow_threads`.
/// Each returned ID is converted into this module's `SubmissionId`, and a
/// client error is wrapped as `CError(R(..))`.
pub fn lookup_submission_ids_by_strategic_metadata(
    &self,
    py: Python<'_>,
    strategic_metadata: StrategicMetadataMap,
) -> CPyResult<Vec<SubmissionId>, E<FatalPythonException, InternalProducerClientError>> {
    py.allow_threads(|| {
        self.block_unless_interrupted(async {
            let lookup = self
                .producer_client
                .lookup_submission_ids_by_strategic_metadata(&strategic_metadata)
                .await;
            match lookup {
                Ok(ids) => Ok(ids.into_iter().map(Into::into).collect()),
                Err(err) => Err(CError(R(err))),
            }
        })
    })
}

/// Directly inserts a submission without sending the chunks to GCS
/// (but immediately embedding them in the DB).
/// NOTE: This does not support StrategicMetadata currently
Expand Down
52 changes: 52 additions & 0 deletions libs/opsqueue_python/tests/test_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from collections.abc import Iterator, Sequence
from opsqueue.producer import (
LookupIdsWithEmptyStrategicMetadataError,
SubmissionId,
ProducerClient,
SubmissionCompleted,
Expand Down Expand Up @@ -508,3 +509,54 @@ def consume(x: int) -> int | None:
with pytest.raises(SubmissionFailedError) as exc_info:
producer_client.blocking_stream_completed_submission(submission_id)
assert exc_info.value.submission.chunks_done == len(chunks) - 1


def test_lookup_submission_ids_by_strategic_metadata(opsqueue: OpsqueueProcess) -> None:
    """Lookup of submission IDs should only match in-progress submissions that
    carry every requested piece of strategic metadata.
    """
    url = "file:///tmp/opsqueue/test_lookup_submission_ids_by_strategic_metadata"
    producer_client = ProducerClient(f"localhost:{opsqueue.port}", url)
    id_1 = producer_client.insert_submission(
        [1], chunk_size=1, strategic_metadata={"foo": 1, "bar": 2, "wow": 3}
    )
    id_2 = producer_client.insert_submission(
        [1], chunk_size=1, strategic_metadata={"foo": 1, "bar": 2, "moo": 3}
    )
    # Same keys as above but with different values, so this submission must
    # never show up in any of the lookups below.
    producer_client.insert_submission(
        [1], chunk_size=1, strategic_metadata={"foo": 2, "bar": 1}
    )

    def assert_lookup(
        strategic_metadata: dict[str, int], expected_ids: list[SubmissionId]
    ) -> None:
        """Runs one lookup and checks both the result's types and its exact
        contents, including order."""
        found_ids = producer_client.lookup_submission_ids_by_strategic_metadata(
            strategic_metadata
        )
        assert isinstance(found_ids, list)
        assert all(isinstance(found_id, SubmissionId) for found_id in found_ids)
        # The comparison is order-sensitive. The expected lists below assume
        # the earlier submission is returned first -- NOTE(review): confirm
        # that IDs are handed out in increasing order.
        assert found_ids == expected_ids

    assert_lookup({"foo": 1}, [id_1, id_2])
    assert_lookup({"foo": 1, "bar": 2}, [id_1, id_2])
    # One matching pair is not enough: every requested pair must be present.
    assert_lookup({"foo": 1, "MISS": 2}, [])
    assert_lookup({"wow": 3}, [id_1])

    # Should only match in-progress submissions.
    producer_client.cancel_submission(id_1)
    assert_lookup({"foo": 1}, [id_2])


def test_lookup_submission_ids_by_empty_strategic_metadata(
    opsqueue: OpsqueueProcess,
) -> None:
    """Looking up submission IDs with an empty 'strategic_metadata' must raise
    a LookupIdsWithEmptyStrategicMetadataError.
    """
    client = ProducerClient(
        f"localhost:{opsqueue.port}",
        "file:///tmp/opsqueue/test_lookup_submission_ids_by_empty_strategic_metadata",
    )
    no_metadata: dict[str, int] = {}
    with pytest.raises(LookupIdsWithEmptyStrategicMetadataError):
        client.lookup_submission_ids_by_strategic_metadata(no_metadata)
31 changes: 30 additions & 1 deletion opsqueue/src/common/submission.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ pub mod db {
db::{Connection, True, WriterConnection, WriterPool},
};
use chunk::ChunkSize;
use sqlx::{query, query_as, Sqlite};
use sqlx::{query, query_as, QueryBuilder, Row, Sqlite};

use axum_prometheus::metrics::{counter, histogram};

Expand Down Expand Up @@ -527,6 +527,35 @@ pub mod db {
Ok(row.map(|row| row.id))
}

pub async fn lookup_ids_by_strategic_metadata(
strategic_metadata: StrategicMetadataMap,
mut conn: impl Connection,
) -> Result<Vec<SubmissionId>, DatabaseError> {
// SQLx currently only supports "WHERE X IN (a, ...)" queries for postgres:
// https://github.com/transact-rs/sqlx/blob/main/FAQ.md#how-can-i-do-a-select--where-foo-in--query
// So we workaround this by manually building the query, foregoing
// sqlx's nice type-checking.
let mut query_builder: QueryBuilder<Sqlite> = QueryBuilder::new(
"
SELECT submission_id
FROM submissions_metadata
INNER JOIN submissions on submissions.id = submission_id
WHERE (metadata_key, metadata_value) IN (
",
);
query_builder.push_values(strategic_metadata.iter(), |mut b, sm| {
b.push_bind(sm.0).push_bind(sm.1);
});
query_builder.push(") GROUP BY submission_id HAVING count(*) = ");
query_builder.push_bind(strategic_metadata.len() as i64);
query_builder.push(" ORDER BY submission_id");
let rows = query_builder.build().fetch_all(conn.get_inner()).await?;
Ok(rows
.into_iter()
.map(|row| row.get("submission_id"))
.collect())
}

#[tracing::instrument(skip(conn))]
pub async fn submission_status(
id: SubmissionId,
Expand Down
28 changes: 28 additions & 0 deletions opsqueue/src/producer/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
errors::E::{L, R},
errors::{SubmissionNotCancellable, SubmissionNotFound},
submission::{SubmissionId, SubmissionStatus},
StrategicMetadataMap,
},
tracing::CarrierMap,
E,
Expand Down Expand Up @@ -226,6 +227,33 @@ impl Client {
.await
}

/// Asks the server for the IDs of the submissions matching the given
/// strategic metadata.
///
/// Sends `strategic_metadata` as the JSON body of a POST request to
/// `{base_url}/submissions/lookup_ids_by_strategic_metadata` and decodes the
/// JSON response body. A failed attempt is retried according to
/// `retry_policy()` when `InternalProducerClientError::is_ephemeral`
/// accepts the error; every retry is logged at debug level.
pub async fn lookup_submission_ids_by_strategic_metadata(
    &self,
    strategic_metadata: &StrategicMetadataMap,
) -> Result<Vec<SubmissionId>, InternalProducerClientError> {
    (|| async {
        let endpoint = format!(
            "{base_url}/submissions/lookup_ids_by_strategic_metadata",
            base_url = self.base_url
        );
        let response = self
            .http_client
            .post(endpoint)
            .json(strategic_metadata)
            .send()
            .await?
            .error_for_status()?;
        let payload = response.bytes().await?;
        Ok(serde_json::from_slice(&payload)?)
    })
    .retry(retry_policy())
    .when(InternalProducerClientError::is_ephemeral)
    .notify(|err, dur| {
        tracing::debug!("retrying error {err:?} with sleeping {dur:?}");
    })
    .await
}

/// Get the server's version from the `/version` endpoint.
///
/// A successful result will be the value of [`VERSION_CARGO_SEMVER`][crate::VERSION_CARGO_SEMVER]
Expand Down
16 changes: 16 additions & 0 deletions opsqueue/src/producer/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::sync::Arc;

use crate::common::errors::E::{L, R};
use crate::common::submission::{self, SubmissionId};
use crate::common::StrategicMetadataMap;
use crate::db::{self, DBPools};
use axum::extract;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
Expand Down Expand Up @@ -60,6 +62,10 @@ impl ServerState {
"/submissions/lookup_id_by_prefix/{prefix}",
get(lookup_submission_id_by_prefix),
)
.route(
"/submissions/lookup_ids_by_strategic_metadata",
post(lookup_submission_ids_by_strategic_metadata),
)
.route("/submissions/{submission_id}", get(submission_status))
.route("/version", get(crate::server::version_endpoint)) // We're also exposing it here so the producer client can view it
.with_state(self)
Expand Down Expand Up @@ -133,6 +139,16 @@ async fn lookup_submission_id_by_prefix(
Ok(Json(submission_id))
}

/// Handler for `POST /submissions/lookup_ids_by_strategic_metadata`.
///
/// Takes the strategic metadata from the JSON request body, looks up the
/// matching submission IDs on a reader connection and returns them as JSON.
async fn lookup_submission_ids_by_strategic_metadata(
    State(state): State<ServerState>,
    extract::Json(strategic_metadata): extract::Json<StrategicMetadataMap>,
) -> Result<Json<Vec<SubmissionId>>, ServerError> {
    let mut reader = state.pool.reader_conn().await?;
    let matching_ids =
        submission::db::lookup_ids_by_strategic_metadata(strategic_metadata, &mut reader).await?;
    Ok(Json(matching_ids))
}

#[tracing::instrument(level = "debug", skip(state))]
async fn insert_submission(
State(state): State<ServerState>,
Expand Down
Loading