Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions import-automation/workflow/ingestion-helper/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,15 @@ def get_latest_lock_timestamp(database):
raise
return None

def get_updated_nodes(database, timestamp, node_types):
def get_updated_nodes(database, timestamp, node_types, timeout):
"""Gets subject_ids and names from Node table where last_update_timestamp > timestamp.
Yields results to avoid loading all into memory.

Args:
database: google.cloud.spanner.Database object.
timestamp: datetime object to filter by.
node_types: A list of strings representing the node types to filter by.
timeout: Timeout for the spanner client to execute queries.

Yields:
Dictionaries containing subject_id and name.
Expand Down Expand Up @@ -78,7 +79,7 @@ def get_updated_nodes(database, timestamp, node_types):

try:
with database.snapshot() as snapshot:
results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=300)
results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=timeout)
fields = None
for row in results:
if fields is None:
Expand All @@ -104,14 +105,15 @@ def filter_and_convert_nodes(nodes_generator):
yield (node.get("subject_id"), node.get("name"), node.get("types"))


def generate_embeddings_partitioned(database, nodes_generator):
def generate_embeddings_partitioned(database, nodes_generator, timeout):
"""Generates embeddings in batches using standard transactions.
Processes nodes in chunks of 500 to avoid transaction size limits.
Accepts a generator to avoid loading all nodes into memory.

Args:
database: google.cloud.spanner.Database object.
nodes_generator: A generator yielding tuples containing (subject_id, embedding_content).
timeout: Timeout for the spanner client to execute queries.

Returns:
The number of affected rows.
Expand Down Expand Up @@ -149,7 +151,7 @@ def chunked(iterable, n):
param_types = {"nodes": Array(struct_type)}

def _execute_dml(transaction):
return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300)
return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=timeout)

try:
row_count = database.run_in_transaction(_execute_dml)
Expand Down
7 changes: 5 additions & 2 deletions import-automation/workflow/ingestion-helper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
'is_base_dc',
os.environ.get('IS_BASE_DC', 'true').lower() == 'true',
'Is base DC')
flags.DEFINE_integer(
'timeout', int(os.environ.get('TIMEOUT', 1700)),
'Timeout in seconds for spanner client to execute queries')

if not FLAGS.is_parsed():
FLAGS(['ingestion_helper'])
Expand Down Expand Up @@ -246,9 +249,9 @@ def ingestion_helper(request):
try:
logging.info(f"Job started. Fetching all nodes for types: {node_types}")
timestamp = get_latest_lock_timestamp(spanner.database)
nodes = get_updated_nodes(spanner.database, timestamp, node_types)
nodes = get_updated_nodes(spanner.database, timestamp, node_types, timeout=FLAGS.timeout)
converted_nodes = filter_and_convert_nodes(nodes)
affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes)
affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes, timeout=FLAGS.timeout)
return (f"OK [Affected rows: {affected_rows}]", 200)
except Exception as e:
logging.error(f"Embedding ingestion failed: {e}")
Expand Down
Loading