diff --git a/import-automation/workflow/ingestion-helper/embedding_utils.py b/import-automation/workflow/ingestion-helper/embedding_utils.py index 9c22d5f938..442e486d3f 100644 --- a/import-automation/workflow/ingestion-helper/embedding_utils.py +++ b/import-automation/workflow/ingestion-helper/embedding_utils.py @@ -43,7 +43,7 @@ def get_latest_lock_timestamp(database): raise return None -def get_updated_nodes(database, timestamp, node_types): +def get_updated_nodes(database, timestamp, node_types, timeout): """Gets subject_ids and names from Node table where last_update_timestamp > timestamp. Yields results to avoid loading all into memory. @@ -51,6 +51,7 @@ def get_updated_nodes(database, timestamp, node_types): database: google.cloud.spanner.Database object. timestamp: datetime object to filter by. node_types: A list of strings representing the node types to filter by. + timeout: Timeout for the spanner client to execute queries. Yields: Dictionaries containing subject_id and name. @@ -78,7 +79,7 @@ def get_updated_nodes(database, timestamp, node_types): try: with database.snapshot() as snapshot: - results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=300) + results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=timeout) fields = None for row in results: if fields is None: @@ -104,7 +105,7 @@ def filter_and_convert_nodes(nodes_generator): yield (node.get("subject_id"), node.get("name"), node.get("types")) -def generate_embeddings_partitioned(database, nodes_generator): +def generate_embeddings_partitioned(database, nodes_generator, timeout): """Generates embeddings in batches using standard transactions. Processes nodes in chunks of 500 to avoid transaction size limits. Accepts a generator to avoid loading all nodes into memory. @@ -112,6 +113,7 @@ def generate_embeddings_partitioned(database, nodes_generator): Args: database: google.cloud.spanner.Database object. nodes_generator: A generator yielding tuples containing (subject_id, embedding_content). + timeout: Timeout for the spanner client to execute queries. Returns: The number of affected rows. @@ -149,7 +151,7 @@ def chunked(iterable, n): param_types = {"nodes": Array(struct_type)} def _execute_dml(transaction): - return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300) + return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=timeout) try: row_count = database.run_in_transaction(_execute_dml) diff --git a/import-automation/workflow/ingestion-helper/main.py b/import-automation/workflow/ingestion-helper/main.py index 8ec128c77c..f8595d6fb7 100644 --- a/import-automation/workflow/ingestion-helper/main.py +++ b/import-automation/workflow/ingestion-helper/main.py @@ -45,6 +45,9 @@ 'is_base_dc', os.environ.get('IS_BASE_DC', 'true').lower() == 'true', 'Is base DC') +flags.DEFINE_integer( + 'timeout', int(os.environ.get('TIMEOUT', 1700)), + 'Timeout in seconds for spanner client to execute queries') if not FLAGS.is_parsed(): FLAGS(['ingestion_helper']) @@ -246,9 +249,9 @@ def ingestion_helper(request): try: logging.info(f"Job started. Fetching all nodes for types: {node_types}") timestamp = get_latest_lock_timestamp(spanner.database) - nodes = get_updated_nodes(spanner.database, timestamp, node_types) + nodes = get_updated_nodes(spanner.database, timestamp, node_types, timeout=FLAGS.timeout) converted_nodes = filter_and_convert_nodes(nodes) - affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes) + affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes, timeout=FLAGS.timeout) return (f"OK [Affected rows: {affected_rows}]", 200) except Exception as e: logging.error(f"Embedding ingestion failed: {e}")