From c59b24b54b81e1a0e0d4dcc62b3b6497dbeaca3f Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 25 Mar 2026 11:38:43 -0400 Subject: [PATCH 1/5] Refactor to a unified v3 to v4 master migrator - Created master_migration.py which streamlines direct structure translation from v3 and v3.1 to v4. - Refactored run_all_v4_migrations.py to expose a reusable batch migrations function. - Updated cli.py to seamlessly accept directory migrations leveraging the new master hook. - Fixed scope and schema edge cases, ensuring robust validation across the tests. --- temoa/cli.py | 69 ++-- temoa/utilities/master_migration.py | 429 +++++++++++++++++++++++ temoa/utilities/run_all_v4_migrations.py | 139 ++++---- tests/test_cli.py | 19 +- tests/test_v4_migration.py | 18 +- 5 files changed, 561 insertions(+), 113 deletions(-) create mode 100644 temoa/utilities/master_migration.py diff --git a/temoa/cli.py b/temoa/cli.py index 45aea3a3..b6766774 100644 --- a/temoa/cli.py +++ b/temoa/cli.py @@ -1,4 +1,3 @@ -import argparse import logging import shutil from datetime import UTC, datetime @@ -16,7 +15,8 @@ from temoa._internal.temoa_sequencer import TemoaSequencer from temoa.core.config import TemoaConfig from temoa.core.modes import TemoaMode -from temoa.utilities import db_migration_v3_1_to_v4, sql_migration_v3_1_to_v4 +from temoa.utilities import master_migration +from temoa.utilities.run_all_v4_migrations import run_migrations # ============================================================================= # Logging & Helper Setup @@ -399,7 +399,7 @@ def migrate( input_path: Annotated[ Path, typer.Argument( - help='Path to input file to migrate (SQL dump or SQLite DB).', + help='Path to input file or directory to migrate (SQL dump or SQLite DB).', exists=True, resolve_path=True, ), @@ -431,7 +431,7 @@ def migrate( debug: Annotated[bool, typer.Option('--debug', '-d', help='Enable debug output.')] = False, ) -> None: """ - Migrate a single Temoa database file (SQL dump or 
SQLite DB) from v3.1 to v4 format. + Migrate a Temoa database file (SQL dump or SQLite DB) or directory from v3 to v4 format. """ if schema_path is None: schema_path = get_default_schema() @@ -439,21 +439,40 @@ def migrate( rich.print(f'[red]Error: Schema file {schema_path} does not exist or is not a file.[/red]') raise typer.Exit(1) - # Validate that input_path is a file, not a directory - if not input_path.is_file(): - rich.print(f'[red]Error: Input path must be a file, not a directory: {input_path}[/red]') - raise typer.Exit(1) + # 1. Directory Migration + if input_path.is_dir(): + if output_path is not None: + rich.print( + '[yellow]Warning: --output is ignored when migrating a directory. Originals are overwritten after backup.[/yellow]' + ) + + migration_script = Path(__file__).parent / 'utilities' / 'master_migration.py' + if not silent: + rich.print(f'[green]Batch migrating directory: {input_path}[/green]') + + try: + run_migrations( + input_dir=input_path, + migration_script=migration_script, + schema_path=schema_path, + dry_run=False, + ) + if not silent: + rich.print(f'[green]Directory migration completed for {input_path}[/green]') + except Exception as e: + logger.exception('Directory migration failed') + rich.print(f'[red]Directory migration failed for {input_path}: {e}[/red]') + raise typer.Exit(1) from e + return + # 2. 
Single File Migration ext = input_path.suffix.lower() - # Determine the effective output directory and file effective_output_dir: Path final_output_file: Path if output_path: - # If explicit output_path is provided, its parent is the desired directory effective_output_dir = output_path.parent - # Ensure the explicitly provided output_path parent exists try: effective_output_dir.mkdir(parents=True, exist_ok=True) except OSError as e: @@ -463,12 +482,10 @@ def migrate( raise typer.Exit(1) from e final_output_file = effective_output_dir / output_path.name else: - # Try to use the input file's directory input_dir = input_path.parent if _is_writable(input_dir): effective_output_dir = input_dir else: - # Fallback to current working directory if input_dir is not writable current_dir = Path.cwd() if _is_writable(current_dir): effective_output_dir = current_dir @@ -485,7 +502,6 @@ def migrate( ) raise typer.Exit(1) - # Ensure the chosen output directory exists try: effective_output_dir.mkdir(parents=True, exist_ok=True) except OSError as e: @@ -495,26 +511,17 @@ def migrate( ) raise typer.Exit(1) from e - # For auto-output, derive filename from input_path, place in effective_output_dir - # Determine output file extension based on migration type if migration_type == 'db' or (migration_type is None and ext in ['.db', '.sqlite']): - # If migrating to DB, output should be .sqlite final_output_file = effective_output_dir / (input_path.stem + '_v4.sqlite') else: - # Default to .sql if migrating SQL dump or type 'auto' for .sql input final_output_file = effective_output_dir / (input_path.stem + '_v4.sql') # --- Execute the migration based on type --- if migration_type == 'sql' or (migration_type is None and ext == '.sql'): - # SQL dump to SQL dump migration - args_namespace = argparse.Namespace( - input=str(input_path), - schema=str(schema_path), - output=str(final_output_file), - debug=debug, - ) try: - sql_migration_v3_1_to_v4.migrate_dump_to_sqlite(args_namespace) + 
master_migration.migrate_sql_dump( + source_path=input_path, schema_path=schema_path, output_path=final_output_file + ) if not silent: rich.print(f'[green]SQL dump migration completed: {final_output_file}[/green]') except Exception as e: @@ -524,14 +531,10 @@ def migrate( ) raise typer.Exit(1) from e elif migration_type == 'db' or (migration_type is None and ext in ['.db', '.sqlite']): - # SQLite DB to SQLite DB migration - args_namespace = argparse.Namespace( - source=str(input_path), - schema=str(schema_path), - out=str(final_output_file), - ) try: - db_migration_v3_1_to_v4.migrate_all(args_namespace) + master_migration.migrate_database( + source_path=input_path, schema_path=schema_path, output_path=final_output_file + ) if not silent: rich.print(f'[green]Database migration completed: {final_output_file}[/green]') except Exception as e: diff --git a/temoa/utilities/master_migration.py b/temoa/utilities/master_migration.py new file mode 100644 index 00000000..afd57016 --- /dev/null +++ b/temoa/utilities/master_migration.py @@ -0,0 +1,429 @@ +import argparse +import re +import sqlite3 +from pathlib import Path +from typing import Any + +# Mapping config +CUSTOM_MAP: dict[str, str] = { + 'TimeNext': 'time_manual', + 'CommodityDStreamProcess': 'commodity_down_stream_process', + 'commodityUStreamProcess': 'commodity_up_stream_process', + 'SegFrac': 'segment_fraction', + 'segfrac': 'segment_fraction', + 'MetaDataReal': 'metadata_real', + 'MetaData': 'metadata', + 'Myopicefficiency': 'myopic_efficiency', + 'DB_MAJOR': 'db_major', + 'DB_MINOR': 'db_minor', +} + +CUSTOM_EXACT_ONLY = {'time_season', 'time_season_sequential'} +CUSTOM_KEYS_SORTED = sorted( + [k for k in CUSTOM_MAP.keys() if k not in CUSTOM_EXACT_ONLY], key=lambda k: -len(k) +) + +OPERATOR_ADDED_TABLES = { + 'EmissionLimit': ('limit_emission', 'le'), + 'TechOutputSplit': ('limit_tech_output_split', 'ge'), + 'TechInputSplitAnnual': ('limit_tech_input_split_annual', 'ge'), + 'TechInputSplitAverage': 
('limit_tech_input_split_annual', 'ge'),
    'TechInputSplit': ('limit_tech_input_split', 'ge'),
    'MinNewCapacityShare': ('limit_new_capacity_share', 'ge'),
    'MinNewCapacityGroupShare': ('limit_new_capacity_share', 'ge'),
    'MinNewCapacityGroup': ('limit_new_capacity', 'ge'),
    'MinNewCapacity': ('limit_new_capacity', 'ge'),
    'MinCapacityShare': ('limit_capacity_share', 'ge'),
    'MinCapacityGroup': ('limit_capacity', 'ge'),
    'MinCapacity': ('limit_capacity', 'ge'),
    'MinActivityShare': ('limit_activity_share', 'ge'),
    'MinActivityGroup': ('limit_activity', 'ge'),
    'MinActivity': ('limit_activity', 'ge'),
    'MaxNewCapacityShare': ('limit_new_capacity_share', 'le'),
    'MaxNewCapacityGroupShare': ('limit_new_capacity_share', 'le'),
    'MaxNewCapacityGroup': ('limit_new_capacity', 'le'),
    'MaxNewCapacity': ('limit_new_capacity', 'le'),
    'MaxCapacityShare': ('limit_capacity_share', 'le'),
    'MaxCapacityGroup': ('limit_capacity', 'le'),
    'MaxCapacity': ('limit_capacity', 'le'),
    'MaxActivityShare': ('limit_activity_share', 'le'),
    'MaxActivityGroup': ('limit_activity', 'le'),
    'MaxActivity': ('limit_activity', 'le'),
    'MaxResource': ('limit_resource', 'le'),
}

# v4 target tables that index on 'vintage' where the v3 source table used
# 'period'; rows for these tables have the period value repositioned into
# the vintage column slot during the operator-table migration below.
PERIOD_TO_VINTAGE_TABLES = {
    'limit_new_capacity_share',
    'limit_new_capacity',
}


def to_snake_case(s: str) -> str:
    """Convert a CamelCase/mixed token to snake_case.

    Strings that are already lower-case and contain an underscore are
    returned unchanged; hyphens and spaces are normalized to underscores,
    and runs of multiple underscores are collapsed.
    """
    if not s:
        return s
    if s == s.lower() and '_' in s:
        return s
    x = s.replace('-', '_').replace(' ', '_')
    x = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', x)
    x = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', x)
    x = re.sub(r'__+', '_', x)
    return x.lower()


def map_token_no_cascade(token: str) -> str:
    """Map a v3 table/column token to its v4 (snake_case) name.

    Resolution order:
      1. tokens that already equal a CUSTOM_MAP value are returned lowered;
      2. exact (then case-insensitive) CUSTOM_MAP keys win;
      3. tokens containing an upper-case letter are plainly snake_cased;
      4. otherwise, CUSTOM_MAP keys embedded in the lower-case token are
         substituted in a single left-to-right pass, longest key first
         (CUSTOM_KEYS_SORTED), so one replacement never cascades into
         another ("no cascade").
    """
    if not token:
        return token
    mapped_values = {v.lower() for v in CUSTOM_MAP.values()}
    if token.lower() in mapped_values:
        return token.lower()
    if token in CUSTOM_MAP:
        return CUSTOM_MAP[token].lower()
    tl = token.lower()
    for k, v in CUSTOM_MAP.items():
        if tl == k.lower():
            return v.lower()
    if any(c.isupper() for c in token):
        return to_snake_case(token)
    orig = token
    orig_lower = orig.lower()
    replacements: list[tuple[str, str]] = [
        (k, CUSTOM_MAP[k]) for k in CUSTOM_KEYS_SORTED if k.lower() in orig_lower
    ]
    if replacements:
        out = []
        i = 0
        length = len(orig)
        # Single scan: at each position try the candidate keys (already
        # sorted longest-first) and consume the match, so replaced text is
        # never re-examined for further substitution.
        while i < length:
            matched = False
            for key, repl in replacements:
                kl = len(key)
                if i + kl <= length and orig[i : i + kl].lower() == key.lower():
                    out.append(repl)
                    i += kl
                    matched = True
                    break
            if not matched:
                out.append(orig[i])
                i += 1
        mapped_once = ''.join(out)
        mapped_once = re.sub(r'__+', '_', mapped_once).lower()
        return mapped_once
    return to_snake_case(token)


def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...]]:
    """Return the PRAGMA table_info rows for *table*, or [] if it is absent.

    The table name is interpolated directly into the PRAGMA statement; all
    callers pass names read from sqlite_master or internal constants.
    """
    try:
        return conn.execute(f'PRAGMA table_info({table});').fetchall()
    except sqlite3.OperationalError:
        return []


def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> None:
    """Copy all data from an open v3/v3.1 connection into an open v4 connection.

    The v4 schema must already exist in ``con_new``. Four phases run in order:
      1. min/max limit tables collapse into v4 operator tables ('le'/'ge');
      2. remaining tables copy across via name/column mapping;
      3. custom per-table logic (loan lifetimes, time-slice aggregation, ...);
      4. final fix-ups (flag rename, DB version metadata).
    Progress is reported via print(); no commit is issued here.
    """
    old_tables = [
        r[0]
        for r in con_old.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    ]
    new_tables = [
        r[0]
        for r in con_new.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    ]
    total = 0  # running count of rows written into the v4 DB

    # 1. Handle operator-added tables: v3 Min*/Max* tables merge into a single
    # v4 table that carries an explicit 'operator' column ('ge' or 'le').
    print('--- Migrating max/min tables to operator constraints ---')
    for old_name, (new_name, operator) in OPERATOR_ADDED_TABLES.items():
        try:
            data = con_old.execute(f'SELECT * FROM {old_name}').fetchall()
        except sqlite3.OperationalError:
            continue
        if not data:
            continue

        new_cols = [c[1] for c in get_table_info(con_new, new_name)]
        if 'operator' not in new_cols:
            continue

        # Splice the operator constant into each row at the position of the
        # v4 'operator' column; trailing v3 columns beyond the v4 width drop.
        op_index = new_cols.index('operator')
        data = [(*row[0:op_index], operator, *row[op_index : len(new_cols) - 1]) for row in data]

        # Move period to vintage if applicable (see PERIOD_TO_VINTAGE_TABLES).
        if new_name in PERIOD_TO_VINTAGE_TABLES:
            old_cols = [c[1] for c in get_table_info(con_old, old_name)]
            if 'period' in old_cols and 'vintage' in new_cols:
                period_index = old_cols.index('period')
                vintage_index = new_cols.index('vintage')
                data = [
                    (
                        *row[0:period_index],
                        *row[period_index + 1 : vintage_index + 1],
                        row[period_index],
                        *row[vintage_index + 1 :],
                    )
                    for row in data
                ]

        placeholders = ','.join(['?'] * len(data[0]))
        query = f'INSERT OR REPLACE INTO {new_name} VALUES ({placeholders})'
        con_new.executemany(query, data)
        print(f'Migrated {len(data)} rows: {old_name} -> {new_name}')
        total += len(data)

    # 2. Standard copied tables: everything except sqlite_* internals and the
    # tables handled by phase 1 or the custom logic in phase 3.
    custom_handled_old_tables = {
        'MetaData',
        'MetaDataReal',
        'TimeSeason',
        'time_season',
        'time_of_day',
        'time_season_sequential',
        'TimeSeasonSequential',
        'TimeSegmentFraction',
        'LoanLifetimeTech',
        'CapacityFactorProcess',
    }.union(OPERATOR_ADDED_TABLES.keys())

    print('--- Executing standard table migrations ---')
    for old in old_tables:
        if old.lower().startswith('sqlite_') or old in custom_handled_old_tables:
            continue

        new = map_token_no_cascade(old)
        if new not in new_tables:
            # Fall back to a fuzzy match (prefix/substring) only when it is
            # unambiguous; otherwise skip the table entirely.
            candidates = [t for t in new_tables if t == new or t.startswith(new) or new in t]
            if len(candidates) == 1:
                new = candidates[0]
            else:
                continue

        old_cols = [c[1] for c in get_table_info(con_old, old)]
        if not old_cols:
            continue
        new_cols = [c[1] for c in get_table_info(con_new, new)]

        # Keep only columns whose mapped name exists in the v4 table.
        selectable_old_cols, insert_new_cols = [], []
        for oc in old_cols:
            mapped = map_token_no_cascade(oc)
            if mapped == 'seg_frac':
                mapped = 'segment_fraction'
            if mapped in new_cols:
                selectable_old_cols.append(oc)
                insert_new_cols.append(mapped)

        if not selectable_old_cols:
            continue

        sel_clause = ','.join(selectable_old_cols)
        rows = con_old.execute(f'SELECT {sel_clause} FROM {old}').fetchall()
        # Drop rows that are entirely NULL in the selected columns.
        filtered = [r for r in rows if any(v is not None for v in r)]
        if not filtered:
            continue

        placeholders = ','.join(['?'] * len(insert_new_cols))
        q = f'INSERT OR REPLACE INTO {new} ({",".join(insert_new_cols)}) VALUES ({placeholders})'
        con_new.executemany(q, filtered)
        print(f'Copied {len(filtered)} rows: {old} -> {new}')
        total += len(filtered)

    # 3. Custom per-table migration logic
    print('--- Processing custom migration logic ---')

    # 3.1 LoanLifetimeTech -> loan_lifetime_process: expand each (region, tech)
    # lifetime to one row per vintage found in the Efficiency table.
    try:
        data = con_old.execute(
            'SELECT region, tech, lifetime, notes FROM LoanLifetimeTech'
        ).fetchall()
        if data:
            new_data = []
            for row in data:
                # NOTE(review): region/tech are interpolated into the SQL via
                # an f-string. Values come from the source DB itself, but
                # parameterized '?' placeholders would be safer — TODO confirm.
                vints = [
                    v[0]
                    for v in con_old.execute(
                        f'SELECT vintage FROM Efficiency WHERE region="{row[0]}" AND tech="{row[1]}"'
                    ).fetchall()
                ]
                for v in vints:
                    new_data.append((row[0], row[1], v, row[2], row[3]))
            con_new.executemany(
                'INSERT OR REPLACE INTO loan_lifetime_process (region, tech, vintage, lifetime, notes) VALUES (?,?,?,?,?)',
                new_data,
            )
            print(f'Migrated {len(new_data)} rows: LoanLifetimeTech -> loan_lifetime_process')
            total += len(new_data)
    except sqlite3.OperationalError:
        pass

    # 3.2 time_season: aggregate per-season segment fractions out of
    # TimeSegmentFraction, averaging across periods when a 'period' column
    # exists in the source table.
    try:
        old_data = []
        cols = [c[1] for c in get_table_info(con_old, 'TimeSegmentFraction')]
        if 'period' in cols:
            old_data = con_old.execute(
                'SELECT season, SUM(segfrac) / COUNT(DISTINCT period) FROM TimeSegmentFraction GROUP BY season'
            ).fetchall()
        else:
            old_data = con_old.execute(
                'SELECT season, SUM(segfrac) FROM TimeSegmentFraction GROUP BY season'
            ).fetchall()

        if old_data:
            con_new.executemany(
                'INSERT OR REPLACE INTO time_season (season, segment_fraction) VALUES (?, ?)',
                old_data,
            )
            print(f'Propagated {len(old_data)} seasons to time_season.')
            total += len(old_data)
    except sqlite3.OperationalError:
        pass

    # 3.3 time_of_day: aggregate per-tod fractions and scale by 24.0 to get
    # hours (normalized by the number of periods when one is present).
    try:
        old_data = []
        cols = [c[1] for c in get_table_info(con_old, 'TimeSegmentFraction')]
        if 'period' in cols:
            old_data = con_old.execute(
                'SELECT tod, SUM(segfrac) FROM TimeSegmentFraction GROUP BY tod'
            ).fetchall()
            if old_data:
                num_periods = (
                    con_old.execute(
                        'SELECT COUNT(DISTINCT period) FROM TimeSegmentFraction'
                    ).fetchone()[0]
                    or 1
                )
                normalized_data = [(r[0], (r[1] / num_periods) * 24.0) for r in old_data]
                con_new.executemany(
                    'INSERT OR REPLACE INTO time_of_day (tod, hours) VALUES (?, ?)', normalized_data
                )
                print(f'Propagated {len(normalized_data)} time-of-day slots to time_of_day.')
                total += len(normalized_data)
        else:
            old_data = con_old.execute(
                'SELECT tod, SUM(segfrac) FROM TimeSegmentFraction GROUP BY tod'
            ).fetchall()
            if old_data:
                normalized_data = [(r[0], r[1] * 24.0) for r in old_data]
                con_new.executemany(
                    'INSERT OR REPLACE INTO time_of_day (tod, hours) VALUES (?, ?)', normalized_data
                )
                print(f'Propagated {len(normalized_data)} time-of-day slots to time_of_day.')
                total += len(normalized_data)
    except sqlite3.OperationalError:
        pass

    # 3.4 time_season_sequential: take the first period's rows (when a
    # 'period' column exists) and convert num_days to a year fraction.
    try:
        old_data = []
        cols = [c[1] for c in get_table_info(con_old, 'TimeSeasonSequential')]
        if 'period' in cols:
            first_period = con_old.execute(
                'SELECT MIN(period) FROM TimeSeasonSequential'
            ).fetchone()[0]
            # NOTE(review): truthiness check treats a period value of 0 (or
            # NULL) as "no period" — confirm periods are always > 0.
            if first_period:
                old_data = con_old.execute(
                    'SELECT seas_seq, season, (num_days / 365.25) FROM TimeSeasonSequential WHERE period = ?',
                    (first_period,),
                ).fetchall()
        else:
            old_data = con_old.execute(
                'SELECT seas_seq, season, (num_days / 365.25) FROM TimeSeasonSequential'
            ).fetchall()

        if old_data:
            con_new.executemany(
                'INSERT OR REPLACE INTO time_season_sequential '
                '(seas_seq, season, segment_fraction) VALUES (?, ?, ?)',
                old_data,
            )
            print(f'Propagated {len(old_data)} sequential seasons to time_season_sequential.')
            total += len(old_data)
    except sqlite3.OperationalError:
        pass

    # 3.5 CapacityFactorProcess: average factors over periods when a 'period'
    # column exists, otherwise copy rows straight across.
    try:
        old_data = []
        cols = [c[1] for c in get_table_info(con_old, 'CapacityFactorProcess')]
        if cols:
            if 'period' in cols:
                old_data = con_old.execute(
                    'SELECT region, season, tod, tech, vintage, AVG(factor) FROM CapacityFactorProcess '
                    'GROUP BY region, season, tod, tech, vintage'
                ).fetchall()
            else:
                old_data = con_old.execute(
                    'SELECT region, season, tod, tech, vintage, factor FROM CapacityFactorProcess'
                ).fetchall()
            if old_data:
                con_new.executemany(
                    'INSERT OR REPLACE INTO capacity_factor_process (region, season, tod, tech, vintage, factor) VALUES (?,?,?,?,?,?)',
                    old_data,
                )
                print(
                    f'Copied {len(old_data)} rows: CapacityFactorProcess -> capacity_factor_process'
                )
                total += len(old_data)
    except sqlite3.OperationalError:
        pass

    # 4. Final updates: rename the retired 'r' technology flag to 'p' and
    # stamp the DB version metadata for v4.0.
    con_new.execute("UPDATE technology SET flag='p' WHERE flag='r';")
    con_new.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MAJOR', 4, '')")
    con_new.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MINOR', 0, '')")
    print(f'Total rows successfully copied: {total}')


def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> None:
    """Migrate a v3/v3.1 SQLite DB file to a new v4 SQLite DB file.

    Builds ``output_path`` from the v4 schema, copies all data with foreign
    keys disabled, then commits, VACUUMs, and re-enables foreign keys.
    """
    con_old = sqlite3.connect(source_path)
    con_new = sqlite3.connect(output_path)
    with open(schema_path, encoding='utf-8') as f:
        con_new.executescript(f.read())

    # Disable FK enforcement while rows arrive out of dependency order.
    con_new.execute('PRAGMA foreign_keys = 0;')
    execute_v3_to_v4_migration(con_old, con_new)

    con_new.commit()
    con_new.execute('VACUUM;')
    con_new.execute('PRAGMA foreign_keys = 1;')
    con_old.close()
    con_new.close()


def migrate_sql_dump(source_path: Path, schema_path: Path, output_path: Path) -> None:
    """Migrate a v3/v3.1 SQL dump file to a v4 SQL dump file.

    Loads the dump and the v4 schema into two in-memory databases, runs the
    same migration as :func:`migrate_database`, then writes the migrated
    database back out as a SQL dump via ``iterdump``.
    """
    con_old_in_memory = sqlite3.connect(':memory:')
    with open(source_path, encoding='utf-8') as f:
        con_old_in_memory.executescript(f.read())

    con_new_in_memory = sqlite3.connect(':memory:')
    with open(schema_path, encoding='utf-8') as f:
        con_new_in_memory.executescript(f.read())

    con_new_in_memory.execute('PRAGMA foreign_keys = 0;')
    execute_v3_to_v4_migration(con_old_in_memory, con_new_in_memory)

    con_new_in_memory.commit()
    con_new_in_memory.execute('PRAGMA foreign_keys = 1;')

    with open(output_path, 'w', encoding='utf-8') as f_out:
        for line in con_new_in_memory.iterdump():
            f_out.write(line + '\n')

    con_old_in_memory.close()
    con_new_in_memory.close()


# CLI entry point: the argument parsing body continues on the following line.
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Master Migrator for v3 to v4') + parser.add_argument('--input', '-i', required=True, help='Input DB or SQL file') + parser.add_argument('--schema', '-s', required=True, help='Path to v4 schema SQL') + parser.add_argument('--output', '-o', required=True, help='Output DB or SQL file') + parser.add_argument('--type', choices=['db', 'sql'], required=True, help='Migration type') + args = parser.parse_args() + + input_path = Path(args.input) + schema_path = Path(args.schema) + output_path = Path(args.output) + + if args.type == 'db': + migrate_database(input_path, schema_path, output_path) + elif args.type == 'sql': + migrate_sql_dump(input_path, schema_path, output_path) diff --git a/temoa/utilities/run_all_v4_migrations.py b/temoa/utilities/run_all_v4_migrations.py index dbac5920..a29ade1e 100644 --- a/temoa/utilities/run_all_v4_migrations.py +++ b/temoa/utilities/run_all_v4_migrations.py @@ -25,7 +25,7 @@ def run_command( cmd: list[str], cwd: Path | None = None, capture_output: bool = True -) -> subprocess.CompletedProcess: +) -> subprocess.CompletedProcess[str]: """Helper to run shell commands.""" print(f'Executing: {" ".join(cmd)}') result = subprocess.run(cmd, cwd=cwd, capture_output=capture_output, text=True, check=False) @@ -36,110 +36,84 @@ def run_command( return result -def main() -> None: - parser = argparse.ArgumentParser( - description='Run SQL migration on all .sql files in a directory, overwriting originals.' 
- ) - parser.add_argument( - '--input_dir', - required=True, - type=Path, - help='Directory containing the .sql files to migrate.', - ) - parser.add_argument( - '--migration_script', - required=True, - type=Path, - help='Path to the sql_migration_v_3_1_to_v4.py script.', - ) - parser.add_argument( - '--v4_schema_path', - required=True, - type=Path, - help='Path to the canonical v4 schema SQL file (temoa_schema_v4.sql).', - ) - parser.add_argument( - '--dry_run', - action='store_true', - help='Perform a dry run: show which files would be processed, but do not modify.', - ) - - args = parser.parse_args() - - input_dir = args.input_dir.resolve() - migration_script = args.migration_script.resolve() - v4_schema_path = args.v4_schema_path.resolve() - +def run_migrations( + input_dir: Path, migration_script: Path, schema_path: Path, dry_run: bool = False +) -> None: if not input_dir.is_dir(): print(f'Error: Input directory not found at {input_dir}') sys.exit(1) if not migration_script.is_file(): print(f'Error: Migration script not found at {migration_script}') sys.exit(1) - if not v4_schema_path.is_file(): - print(f'Error: V4 schema file not found at {v4_schema_path}') + if not schema_path.is_file(): + print(f'Error: schema file not found at {schema_path}') sys.exit(1) - print(f'Scanning for .sql files in: {input_dir}') + print(f'Scanning for .sql and .sqlite files in: {input_dir}') sql_files = list(input_dir.glob('*.sql')) + db_files = list(input_dir.glob('*.sqlite')) + list(input_dir.glob('*.db')) + all_files = sql_files + db_files - if not sql_files: - print(f'No .sql files found in {input_dir}. Exiting.') - sys.exit(0) + if not all_files: + print(f'No .sql, .sqlite, or .db files found in {input_dir}. 
Exiting.') + return - if args.dry_run: + if dry_run: print('\n--- Dry Run ---') - print(f'The following {len(sql_files)} .sql files would be processed:') - for f in sql_files: + print(f'The following {len(all_files)} files would be processed:') + for f in all_files: print(f' - {f.name}') print('\nNo files will be modified in dry run mode.') - sys.exit(0) + return - print(f'\n--- Starting Migration of {len(sql_files)} files ---') + print(f'\n--- Starting Migration of {len(all_files)} files ---') processed_count = 0 failed_files = [] - for sql_file in sql_files: - print(f'\nProcessing: {sql_file.name}') + for target_file in all_files: + print(f'\nProcessing: {target_file.name}') - temp_output_file = Path(tempfile.mkstemp(suffix='.sql', prefix='temp_migrated_')[1]) + ext = target_file.suffix.lower() + temp_output_file = Path(tempfile.mkstemp(suffix=ext, prefix='temp_migrated_')[1]) original_backup_file = Path(tempfile.mkstemp(suffix='.bak', prefix='orig_backup_')[1]) + mig_type = 'sql' if ext == '.sql' else 'db' try: - # 1. Back up original file (to restore on failure) - shutil.copy2(sql_file, original_backup_file) + # 1. Back up original file + shutil.copy2(target_file, original_backup_file) # 2. Run migration script, outputting to a temporary file migration_cmd = [ 'python3', str(migration_script), '--input', - str(sql_file), + str(target_file), '--schema', - str(v4_schema_path), + str(schema_path), '--output', str(temp_output_file), + '--type', + mig_type, ] result = run_command(migration_cmd, cwd=Path.cwd()) if result.returncode == 0: - # 3. If successful, overwrite original file with converted content - shutil.copy2(temp_output_file, sql_file) - print(f'SUCCESS: {sql_file.name} migrated and overwritten.') + # 3. If successful, overwrite original file + shutil.copy2(temp_output_file, target_file) + print(f'SUCCESS: {target_file.name} migrated and overwritten.') processed_count += 1 else: # 4. 
On failure, restore original file - print(f'FAILED: Migration for {sql_file.name} failed. Restoring original file.') - shutil.copy2(original_backup_file, sql_file) - failed_files.append(sql_file.name) + print(f'FAILED: Migration for {target_file.name} failed. Restoring original file.') + shutil.copy2(original_backup_file, target_file) + failed_files.append(target_file.name) except Exception as e: - print(f'CRITICAL ERROR processing {sql_file.name}: {e}. Restoring original file.') + print(f'CRITICAL ERROR processing {target_file.name}: {e}. Restoring original file.') if original_backup_file.exists(): - shutil.copy2(original_backup_file, sql_file) - failed_files.append(sql_file.name) + shutil.copy2(original_backup_file, target_file) + failed_files.append(target_file.name) finally: - # Clean up temporary files if temp_output_file.exists(): os.remove(temp_output_file) if original_backup_file.exists(): @@ -152,7 +126,44 @@ def main() -> None: sys.exit(1) else: print('All files migrated successfully.') - sys.exit(0) + + +def main() -> None: + parser = argparse.ArgumentParser( + description='Run script migration on all .sql/.sqlite/.db files in a directory, overwriting originals.' 
+ ) + parser.add_argument( + '--input_dir', + required=True, + type=Path, + help='Directory containing the files to migrate.', + ) + parser.add_argument( + '--migration_script', + required=True, + type=Path, + help='Path to the master_migration.py script.', + ) + parser.add_argument( + '--v4_schema_path', + required=True, + type=Path, + help='Path to the canonical v4 schema SQL file.', + ) + parser.add_argument( + '--dry_run', + action='store_true', + help='Perform a dry run: show which files would be processed, but do not modify.', + ) + + args = parser.parse_args() + + run_migrations( + input_dir=args.input_dir.resolve(), + migration_script=args.migration_script.resolve(), + schema_path=args.v4_schema_path.resolve(), + dry_run=args.dry_run, + ) if __name__ == '__main__': diff --git a/tests/test_cli.py b/tests/test_cli.py index c0bd0a28..b42ca58f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -153,7 +153,7 @@ def test_cli_migrate_help() -> None: result = runner.invoke(app, ['migrate', '--help']) assert result.exit_code == 0 assert 'migrate' in result.stdout - assert 'Migrate a single Temoa database file' in result.stdout + assert 'Migrate a Temoa database file' in result.stdout def test_cli_migrate_sql_file(tmp_path: Path) -> None: @@ -172,19 +172,20 @@ def test_cli_migrate_sql_file(tmp_path: Path) -> None: assert output_file.exists() -def test_cli_migrate_rejects_directory_input(tmp_path: Path) -> None: - """Test that the migrate command rejects a directory as input.""" +def test_cli_migrate_accepts_directory_input(tmp_path: Path) -> None: + """Test that the migrate command accepts a directory as input for batch processing.""" dummy_dir = tmp_path / 'my_dummy_dir' dummy_dir.mkdir() args = ['migrate', str(dummy_dir)] result = runner.invoke(app, args) - assert result.exit_code != 0 + assert result.exit_code == 0 # Normalize whitespace to handle platform-specific line breaks from rich.print() normalized_output = ' '.join(result.stdout.split()) - assert 
'Error: Input path must be a file, not a directory:' in normalized_output + assert 'Batch migrating directory' in normalized_output # Check for the directory name in the original output (paths may be split across lines) assert 'my_dummy_dir' in result.stdout + assert 'No .sql, .sqlite, or .db files found' in result.stdout def test_cli_migrate_sql_file_auto_output_writable_input_dir(tmp_path: Path) -> None: @@ -250,10 +251,10 @@ def mock_is_writable(path: Path) -> bool: normalized_output = ' '.join(result.stdout.split()) assert 'SQL dump migration completed' in normalized_output assert 'Warning: Input directory' in normalized_output - assert str(non_writable_mock_parent) in normalized_output - assert 'is not writable.' in normalized_output - assert 'Saving output to current directory:' in normalized_output - assert str(tmp_path) in normalized_output + assert 'mock_non_writable_input_parent' in re.sub(r'\s+', '', result.stdout) + assert 'is not writable' in normalized_output + assert 'Saving output to current directory' in normalized_output + assert tmp_path.name in normalized_output expected_output_in_cwd = tmp_path / (input_file.stem + '_v4.sql') assert expected_output_in_cwd.exists() diff --git a/tests/test_v4_migration.py b/tests/test_v4_migration.py index 644c6fff..534ee211 100644 --- a/tests/test_v4_migration.py +++ b/tests/test_v4_migration.py @@ -14,9 +14,9 @@ def test_v4_migrations(tmp_path: Path) -> None: - """Test both SQL and SQLite v4 migrators.""" + """Test both SQL and SQLite master migrators.""" - # 1. Create v3.1 SQLite DB + # 1. Create v3.1 SQLite DB (acts as our v3 mock since the mock schema is 3.1) db_v3_1 = tmp_path / 'test_v3_1.sqlite' with sqlite3.connect(db_v3_1) as conn: conn.execute('PRAGMA foreign_keys = OFF') @@ -24,7 +24,7 @@ def test_v4_migrations(tmp_path: Path) -> None: conn.executescript(MOCK_DATA.read_text()) conn.execute('PRAGMA foreign_keys = ON') - # 2. Dump v3.1 to SQL + # 2. 
Dump to SQL sql_v3_1 = tmp_path / 'test_v3_1.sql' with open(sql_v3_1, 'w') as f: for line in sqlite3.connect(db_v3_1).iterdump(): @@ -35,7 +35,9 @@ def test_v4_migrations(tmp_path: Path) -> None: subprocess.run( [ sys.executable, - str(UTILITIES_DIR / 'sql_migration_v3_1_to_v4.py'), + str(UTILITIES_DIR / 'master_migration.py'), + '--type', + 'sql', '--input', str(sql_v3_1), '--schema', @@ -57,12 +59,14 @@ def test_v4_migrations(tmp_path: Path) -> None: subprocess.run( [ sys.executable, - str(UTILITIES_DIR / 'db_migration_v3_1_to_v4.py'), - '--source', + str(UTILITIES_DIR / 'master_migration.py'), + '--type', + 'db', + '--input', str(db_v3_1), '--schema', str(SCHEMA_V4), - '--out', + '--output', str(db_v4_migrated), ], check=True, From 6014de7ee5b1c0131327b9f5fe8c391a629656e1 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 25 Mar 2026 11:56:40 -0400 Subject: [PATCH 2/5] Fix Windows CI test failure caused by rich word wrapping in tests/test_cli.py --- tests/test_cli.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index b42ca58f..7bb8468f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -249,12 +249,14 @@ def mock_is_writable(path: Path) -> bool: ) # Normalize whitespace to handle platform-specific line breaks from rich.print() normalized_output = ' '.join(result.stdout.split()) + clean_output = re.sub(r'\s+', '', result.stdout) + assert 'SQL dump migration completed' in normalized_output assert 'Warning: Input directory' in normalized_output - assert 'mock_non_writable_input_parent' in re.sub(r'\s+', '', result.stdout) + assert 'mock_non_writable_input_parent' in clean_output assert 'is not writable' in normalized_output assert 'Saving output to current directory' in normalized_output - assert tmp_path.name in normalized_output + assert tmp_path.name in clean_output expected_output_in_cwd = tmp_path / (input_file.stem + '_v4.sql') assert expected_output_in_cwd.exists() From 
23fce8d341f6c82cff4e88737c75bf36ad0d47cd Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 25 Mar 2026 12:39:25 -0400 Subject: [PATCH 3/5] Finalize migration refactor: address code review, purge legacy scripts, and harden testing --- temoa/cli.py | 1 + temoa/utilities/db_migration_v3_1_to_v4.py | 289 ------------- temoa/utilities/db_migration_v3_to_v3_1.py | 452 -------------------- temoa/utilities/master_migration.py | 39 +- temoa/utilities/run_all_v4_migrations.py | 66 +-- temoa/utilities/sql_migration_v3_1_to_v4.py | 391 ----------------- tests/test_cli.py | 7 +- tests/test_v4_migration.py | 26 +- tests/testing_data/migration_v3_mock.sql | 30 ++ 9 files changed, 122 insertions(+), 1179 deletions(-) delete mode 100644 temoa/utilities/db_migration_v3_1_to_v4.py delete mode 100644 temoa/utilities/db_migration_v3_to_v3_1.py delete mode 100644 temoa/utilities/sql_migration_v3_1_to_v4.py create mode 100644 tests/testing_data/migration_v3_mock.sql diff --git a/temoa/cli.py b/temoa/cli.py index b6766774..66b1d498 100644 --- a/temoa/cli.py +++ b/temoa/cli.py @@ -456,6 +456,7 @@ def migrate( migration_script=migration_script, schema_path=schema_path, dry_run=False, + silent=silent, ) if not silent: rich.print(f'[green]Directory migration completed for {input_path}[/green]') diff --git a/temoa/utilities/db_migration_v3_1_to_v4.py b/temoa/utilities/db_migration_v3_1_to_v4.py deleted file mode 100644 index be294bb7..00000000 --- a/temoa/utilities/db_migration_v3_1_to_v4.py +++ /dev/null @@ -1,289 +0,0 @@ -#!/usr/bin/env python3 -""" -db_migration_v3_1_to_v4.py - -Migrate a v3.1 SQLite DB to a v4 SQLite DB using deterministic mapping rules. 
- -Usage: - python db_migration_v3_1_to_v4.py --source old_v3_1.sqlite \ - --schema temoa_schema_v4.sql \ - --out new_v4.sqlite -""" - -from __future__ import annotations - -import argparse -import re -import sqlite3 -from pathlib import Path - -# ---------- Mapping configuration ---------- -CUSTOM_MAP: dict[str, str] = { - 'TimeSeason': 'time_season', - 'time_season': 'time_season', - 'TimeSeasonSequential': 'time_season_sequential', - 'time_season_sequential': 'time_season_sequential', - 'TimeNext': 'time_manual', - 'CommodityDStreamProcess': 'commodity_down_stream_process', - 'commodityUStreamProcess': 'commodity_up_stream_process', - 'SegFrac': 'segment_fraction', - 'segfrac': 'segment_fraction', - 'MetaDataReal': 'metadata_real', - 'MetaData': 'metadata', - 'Myopicefficiency': 'myopic_efficiency', - 'DB_MAJOR': 'db_major', - 'DB_MINOR': 'db_minor', -} -CUSTOM_EXACT_ONLY = {'time_season', 'time_season_sequential'} -CUSTOM_KEYS_SORTED = sorted( - [k for k in CUSTOM_MAP.keys() if k not in CUSTOM_EXACT_ONLY], key=lambda k: -len(k) -) - - -def to_snake_case(s: str) -> str: - if not s: - return s - if s == s.lower() and '_' in s: - return s - x = s.replace('-', '_').replace(' ', '_') - x = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', x) - x = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', x) - x = re.sub(r'__+', '_', x) - return x.lower() - - -def map_token_no_cascade(token: str) -> str: - if not token: - return token - mapped_values = {v.lower() for v in CUSTOM_MAP.values()} - if token.lower() in mapped_values: - return token.lower() - if token in CUSTOM_MAP: - return CUSTOM_MAP[token].lower() - tl = token.lower() - for k, v in CUSTOM_MAP.items(): - if tl == k.lower(): - return v.lower() - if any(c.isupper() for c in token): - return to_snake_case(token) - orig = token - orig_lower = orig.lower() - replacements: list[tuple[str, str]] = [ - (k, CUSTOM_MAP[k]) for k in CUSTOM_KEYS_SORTED if k.lower() in orig_lower - ] - if replacements: - out = [] - i = 0 - length = len(orig) - 
while i < length: - matched = False - for key, repl in replacements: - kl = len(key) - if i + kl <= length and orig[i : i + kl].lower() == key.lower(): - out.append(repl) - i += kl - matched = True - break - if not matched: - out.append(orig[i]) - i += 1 - mapped_once = ''.join(out) - mapped_once = re.sub(r'__+', '_', mapped_once).lower() - return mapped_once - return to_snake_case(token) - - -def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]: - try: - return conn.execute(f'PRAGMA table_info({table});').fetchall() - except sqlite3.OperationalError: - return [] - - -def migrate_direct_table( - con_old: sqlite3.Connection, con_new: sqlite3.Connection, old_table: str, new_table: str -) -> int: - old_cols = [c[1] for c in get_table_info(con_old, old_table)] - if not old_cols: - return 0 - new_cols = [c[1] for c in get_table_info(con_new, new_table)] - selectable_old_cols, insert_new_cols = [], [] - for oc in old_cols: - mapped = map_token_no_cascade(oc) - if mapped == 'seg_frac': - mapped = 'segment_fraction' - if mapped in new_cols: - selectable_old_cols.append(oc) - insert_new_cols.append(mapped) - if not selectable_old_cols: - return 0 - sel_clause = ','.join(selectable_old_cols) - rows = con_old.execute(f'SELECT {sel_clause} FROM {old_table}').fetchall() - if not rows: - return 0 - # filter out rows that are entirely NULL - filtered = [r for r in rows if any(v is not None for v in r)] - if not filtered: - return 0 - placeholders = ','.join(['?'] * len(insert_new_cols)) - q = f'INSERT OR REPLACE INTO {new_table} ({",".join(insert_new_cols)}) VALUES ({placeholders})' - con_new.executemany(q, filtered) - return len(filtered) - - -def migrate_all(args) -> None: - src = Path(args.source) - schema = Path(args.schema) - out = Path(args.out) if args.out else src.with_suffix('.v4.sqlite') - con_old = sqlite3.connect(src) - con_new = sqlite3.connect(out) - with open(schema, encoding='utf-8') as f: - sql = f.read() - con_new.executescript(sql) - 
con_new.execute('PRAGMA foreign_keys = 0;') - old_tables = [ - r[0] - for r in con_old.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() - ] - new_tables = [ - r[0] - for r in con_new.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() - ] - total = 0 - for old in old_tables: - if old.lower().startswith('sqlite_'): - continue - if old in ( - 'MetaData', - 'MetaDataReal', - 'TimeSeason', - 'time_season', - 'time_of_day', - 'time_season_sequential', - 'TimeSeasonSequential', - ): - continue - if old == 'CapacityFactorProcess': - # Special case: aggregate across periods - old_data = con_old.execute( - 'SELECT region, season, tod, tech, vintage, AVG(factor) ' - 'FROM CapacityFactorProcess ' - 'GROUP BY region, season, tod, tech, vintage' - ).fetchall() - if old_data: - con_new.executemany( - 'INSERT OR REPLACE INTO capacity_factor_process ' - '(region, season, tod, tech, vintage, factor) ' - 'VALUES (?, ?, ?, ?, ?, ?)', - old_data, - ) - print(f'Aggregated {len(old_data)} rows: {old} -> capacity_factor_process') - total += len(old_data) - continue - - new = map_token_no_cascade(old) - if new not in new_tables: - candidates = [t for t in new_tables if t == new or t.startswith(new) or new in t] - if len(candidates) == 1: - new = candidates[0] - else: - print(f'SKIP (no target): {old} -> {new}; candidates={candidates}') - continue - try: - n = migrate_direct_table(con_old, con_new, old, new) - print(f'Copied {n} rows: {old} -> {new}') - total += n - except Exception: - print(f'Error migrating {old} -> {new}') - raise - - # --- Custom logic for restructured tables (Seasons/TOD) --- - print('Processing custom migration logic for seasons and time-of-day...') - - # 1. 
time_season (aggregate from TimeSegmentFraction) - try: - old_data = con_old.execute( - 'SELECT season, SUM(segfrac) / COUNT(DISTINCT period) ' - 'FROM TimeSegmentFraction GROUP BY season' - ).fetchall() - if old_data: - con_new.executemany( - 'INSERT OR REPLACE INTO time_season (season, segment_fraction) VALUES (?, ?)', - old_data, - ) - print(f'Propagated {len(old_data)} seasons to time_season.') - total += len(old_data) - except sqlite3.OperationalError: - print('WARNING: Could not migrate seasons from TimeSegmentFraction.') - - # 2. time_of_day (aggregate from TimeSegmentFraction) - try: - old_data = con_old.execute( - 'SELECT tod, SUM(segfrac) FROM TimeSegmentFraction GROUP BY tod' - ).fetchall() - if old_data: - num_periods = ( - con_old.execute( - 'SELECT COUNT(DISTINCT period) FROM TimeSegmentFraction' - ).fetchone()[0] - or 1 - ) - normalized_data = [(r[0], (r[1] / num_periods) * 24.0) for r in old_data] - con_new.executemany( - 'INSERT OR REPLACE INTO time_of_day (tod, hours) VALUES (?, ?)', normalized_data - ) - print(f'Propagated {len(normalized_data)} time-of-day slots to time_of_day.') - total += len(normalized_data) - except sqlite3.OperationalError: - print('WARNING: Could not migrate time_of_day from TimeSegmentFraction.') - - # 3. 
time_season_sequential (aggregate from TimeSeasonSequential) - try: - first_period = con_old.execute('SELECT MIN(period) FROM TimeSeasonSequential').fetchone()[0] - if first_period: - old_data = con_old.execute( - 'SELECT tss.seas_seq, tss.season, (tss.num_days / 365.25) ' - 'FROM TimeSeasonSequential tss ' - 'WHERE tss.period = ?', - (first_period,), - ).fetchall() - if old_data: - con_new.executemany( - 'INSERT OR REPLACE INTO time_season_sequential ' - '(seas_seq, season, segment_fraction) VALUES (?, ?, ?)', - old_data, - ) - print(f'Propagated {len(old_data)} sequential seasons to time_season_sequential.') - total += len(old_data) - except (sqlite3.OperationalError, TypeError): - pass - - # ensure metadata version bumped - cur = con_new.cursor() - cur.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MAJOR', 4, '')") - cur.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MINOR', 0, '')") - con_new.commit() - con_new.execute('VACUUM;') - con_new.execute('PRAGMA foreign_keys = 1;') - try: - fk = con_new.execute('PRAGMA FOREIGN_KEY_CHECK;').fetchall() - if fk: - print('FK issues:', fk) - except sqlite3.OperationalError: - pass - con_old.close() - con_new.close() - print('Done; approx rows:', total, '->', out) - - -def parse_cli() -> argparse.Namespace: - p = argparse.ArgumentParser() - p.add_argument('--source', required=True) - p.add_argument('--schema', required=True) - p.add_argument('--out', required=False) - return p.parse_args() - - -if __name__ == '__main__': - args = parse_cli() - migrate_all(args) diff --git a/temoa/utilities/db_migration_v3_to_v3_1.py b/temoa/utilities/db_migration_v3_to_v3_1.py deleted file mode 100644 index e5e5180e..00000000 --- a/temoa/utilities/db_migration_v3_to_v3_1.py +++ /dev/null @@ -1,452 +0,0 @@ -""" -Transition a v3.0 database to a v3.1 database. 
-""" - -import argparse -import os -import sqlite3 -import sys -from pathlib import Path - -import pandas as pd - -from temoa.core.model import TemoaModel - -# Just to get the default lifetime... -this_dir = os.path.dirname(__file__) -root_dir = os.path.abspath(os.path.join(this_dir, '../..')) -sys.path.append(root_dir) - - -parser = argparse.ArgumentParser() -parser.add_argument( - '--source', - help='Path to original database', - required=True, - action='store', - dest='source_db', -) -parser.add_argument( - '--schema', - help='Path to schema file (default=data_files/temoa_schema_v3_1)', - required=False, - dest='schema', - default='data_files/temoa_schema_v3_1.sql', -) -options = parser.parse_args() -legacy_db: Path = Path(options.source_db) -schema_file = Path(options.schema) - -new_db_name = legacy_db.stem + '_v3_1.sqlite' -new_db_path = Path(legacy_db.parent, new_db_name) - -con_old = sqlite3.connect(legacy_db) -con_new = sqlite3.connect(new_db_path) -cur = con_new.cursor() - -# bring in the new schema and execute -with open(schema_file) as src: - sql_script = src.read() -con_new.executescript(sql_script) - -# turn off FK verification while process executes -con_new.execute('PRAGMA foreign_keys = 0;') - - -def column_check(old_name: str, new_name: str) -> bool: - if old_name == '': - old_name = new_name - - try: - con_old.execute(f'SELECT * FROM {old_name}').fetchone() - except sqlite3.OperationalError: - return True - - new_columns = [c[1] for c in con_new.execute(f'PRAGMA table_info({new_name});').fetchall()] - old_columns = [c[1] for c in con_old.execute(f'PRAGMA table_info({old_name});').fetchall()] - - missing = [c for c in new_columns if c not in old_columns and c not in ('period', 'notes')] - if len(missing) > 0: - msg = ( - f'Columns of {new_name} in the new database missing from {old_name} in old database. 
' - 'Try adding or renaming the column in the old database:' - f'\n{missing}\n' - ) - print(msg) - return False - return True - - -# table mapping for DIRECT transfers -# fmt: off -direct_transfer_tables = [ - ("", "CapacityCredit"), - ("", "CapacityToActivity"), - ("", "Commodity"), - ("", "CommodityType"), - ("", "CostEmission"), - ("", "CostFixed"), - ("", "CostInvest"), - ("", "CostVariable"), - ("", "Demand"), - ("", "Efficiency"), - ("", "EmissionActivity"), - ("", "ExistingCapacity"), - ("", "LifetimeProcess"), - ("", "LifetimeTech"), - ("", "LinkedTech"), - ("", "LoanRate"), - ("", "MetaData"), - ("", "MetaDataReal"), - ("", "PlanningReserveMargin"), - ("RampDown", "RampDownHourly"), - ("RampUp", "RampUpHourly"), - ("", "Region"), - ("", "RPSRequirement"), - ("", "SectorLabel"), - ("", "StorageDuration"), - ("", "TechGroup"), - ("", "TechGroupMember"), - ("", "Technology"), - ("", "TechnologyType"), - ("", "TimeOfDay"), - ("", "TimePeriod"), - ("", "TimePeriodType"), -] - -period_added_tables = [ - ("", "CapacityFactorProcess"), - ("", "CapacityFactorTech"), - ("", "DemandSpecificDistribution"), - ("", "TimeSeason"), - ("", "TimeSegmentFraction"), -] - -period_to_vintage_tables = { - "LimitNewCapacityShare", - "LimitNewCapacity", -} - -operator_added_tables = { - "EmissionLimit": ("LimitEmission", "le"), - "TechOutputSplit": ("LimitTechOutputSplit", "ge"), - "TechInputSplitAnnual": ("LimitTechInputSplitAnnual", "ge"), - "TechInputSplitAverage": ("LimitTechInputSplitAnnual", "ge"), - "TechInputSplit": ("LimitTechInputSplit", "ge"), - "MinNewCapacityShare": ("LimitNewCapacityShare", "ge"), - "MinNewCapacityGroupShare": ("LimitNewCapacityShare", "ge"), - "MinNewCapacityGroup": ("LimitNewCapacity", "ge"), - "MinNewCapacity": ("LimitNewCapacity", "ge"), - "MinCapacityShare": ("LimitCapacityShare", "ge"), - "MinCapacityGroup": ("LimitCapacity", "ge"), - "MinCapacity": ("LimitCapacity", "ge"), - "MinActivityShare": ("LimitActivityShare", "ge"), - 
"MinActivityGroup": ("LimitActivity", "ge"), - "MinActivity": ("LimitActivity", "ge"), - "MaxNewCapacityShare": ("LimitNewCapacityShare", "le"), - "MaxNewCapacityGroupShare": ("LimitNewCapacityShare", "le"), - "MaxNewCapacityGroup": ("LimitNewCapacity", "le"), - "MaxNewCapacity": ("LimitNewCapacity", "le"), - "MaxCapacityShare": ("LimitCapacityShare", "le"), - "MaxCapacityGroup": ("LimitCapacity", "le"), - "MaxCapacity": ("LimitCapacity", "le"), - "MaxActivityShare": ("LimitActivityShare", "le"), - "MaxActivityGroup": ("LimitActivity", "le"), - "MaxActivity": ("LimitActivity", "le"), - "MaxResource": ("LimitResource", "le"), -} - -no_transfer = { - "MinSeasonalActivity": "LimitSeasonalCapacityFactor", - "MaxSeasonalActivity": "LimitSeasonalCapacityFactor", - "MinAnnualCapacityFactor": "LimitAnnualCapacityFactor", - "MaxAnnualCapacityFactor": "LimitAnnualCapacityFactor", - "StorageInit": "LimitStorageLevelFraction", -} - - -all_good = True -for old_name, new_name in direct_transfer_tables: - all_good = all_good and column_check(old_name, new_name) -for old_name, new_name in period_added_tables: - all_good = all_good and column_check(old_name, new_name) -if not all_good: - sys.exit(1) - - -# Collapse Max/Min constraint tables -print("\n --- Collapsing Max/Min tables and adding operators ---") -for old_name, (new_name, operator) in operator_added_tables.items(): - - try: - data = con_old.execute(f"SELECT * FROM {old_name}").fetchall() - except sqlite3.OperationalError: - print("TABLE NOT FOUND: " + old_name) - continue - - if not data: - print("No data for: " + old_name) - continue - - new_cols: list[str] = [ - c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall() - ] - op_index = new_cols.index("operator") - data = [(*row[0:op_index], operator, *row[op_index:len(new_cols)-1]) for row in data] - # if table in period_to_vintage_tables, move period value from period column to vintage column - if new_name in period_to_vintage_tables: - old_cols: 
list[str] = [ - c[1] for c in con_old.execute(f"PRAGMA table_info({old_name});").fetchall() - ] - period_index = old_cols.index("period") - vintage_index = new_cols.index("vintage") - data = [ - ( - *row[0:period_index], - *row[period_index+1:vintage_index+1], - row[period_index], - *row[vintage_index+1:] - ) - for row in data - ] - - # construct the query with correct number of placeholders - num_placeholders = len(data[0]) - placeholders = ",".join(["?" for _ in range(num_placeholders)]) - query = f"INSERT OR REPLACE INTO {new_name} VALUES ({placeholders})" - con_new.executemany(query, data) - print(f"Transfered {len(data)} rows from {old_name} to {new_name}") - -# It wasn't active anyway... can't be bothered -# StorageInit -> LimitStorageLevelFraction - -# execute the direct transfers -print("\n --- Executing direct transfers ---") -for old_name, new_name in direct_transfer_tables: - if old_name == "": - old_name = new_name - - try: - con_old.execute(f"SELECT * FROM {old_name}").fetchone() - except sqlite3.OperationalError: - print("TABLE NOT FOUND: " + old_name) - continue - - old_columns = [c[1] for c in con_old.execute(f"PRAGMA table_info({old_name});").fetchall()] - new_columns = [c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall()] - cols = [c for c in new_columns if c in old_columns] - data = con_old.execute(f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}').fetchall() - - if not data: - print("No data for: " + old_name) - continue - - # construct the query with correct number of placeholders - num_placeholders = len(data[0]) - placeholders = ",".join(["?" 
for _ in range(num_placeholders)]) - query = ( - "INSERT OR REPLACE INTO " - f"{new_name}{tuple(c for c in cols) if len(cols)>1 else f'({cols[0]})'} " - f"VALUES ({placeholders})" - ) - con_new.executemany(query, data) - print(f"Transfered {len(data)} rows from {old_name} to {new_name}") - -time_all = [ - p[0] for p in cur.execute("SELECT period FROM TimePeriod").fetchall() -] -time_all = sorted(time_all)[0:-1] # Exclude horizon end - -# get lifetimes. Major headache but needs to be done -lifetime_process = {} -data = cur.execute("SELECT region, tech, vintage FROM Efficiency").fetchall() -for rtv in data: - lifetime_process[rtv] = TemoaModel.default_lifetime_tech -data = cur.execute("SELECT region, tech, lifetime FROM LifetimeTech").fetchall() -for rtl in data: - for v in time_all: - lifetime_process[*rtl[0:2], v] = rtl[2] -data = cur.execute("SELECT region, tech, vintage, lifetime FROM LifetimeProcess").fetchall() -for rtvl in data: - lifetime_process[rtvl[0:3]] = rtvl[3] - -# Planning periods to add to period indices -time_optimize = [ - p[0] for p in cur.execute('SELECT period FROM TimePeriod WHERE flag == "f"').fetchall() -] -time_optimize = sorted(time_optimize)[0:-1] # Exclude horizon end - -# add period indexing to seasonal tables -print("\n --- Adding period index to some tables ---") -for old_name, new_name in period_added_tables: - if old_name == "": - old_name = new_name - - try: - con_old.execute(f"SELECT * FROM {old_name}").fetchone() - except sqlite3.OperationalError: - print("TABLE NOT FOUND: " + old_name) - continue - - old_columns = [c[1] for c in con_old.execute(f"PRAGMA table_info({old_name});").fetchall()] - new_columns = [c[1] for c in con_new.execute(f"PRAGMA table_info({new_name});").fetchall()] - cols = [c for c in new_columns if c in old_columns] - data = pd.read_sql_query(f'SELECT {str(cols)[1:-1].replace("'","")} FROM {old_name}', con_old) - - if len(data) == 0: - print("No data for: " + old_name) - continue - - # This insanity collects 
the viable periods for each table - if "vintage" in cols: - data["periods"] = [ - ( - p for p in time_optimize - if v <= p < v+lifetime_process[r, t, v] - ) - for r, t, v in data[["region","tech","vintage"]] - ] - elif "tech" in cols: - periods = {} - for r, t in data[["region","tech"]].drop_duplicates().values: - periods[r, t] = [ - p for p in time_optimize - if any( - v <= p < v+lifetime_process[r, t, v] - for v in [ - t[0] for t in con_old.execute( - f'SELECT vintage FROM Efficiency WHERE region == "{r}" AND ' - f'tech == "{t}"' - ).fetchall() - ] - ) - ] - data["periods"] = [ - periods[r, t] - for (r, t) in data[["region","tech"]].values - ] - else: - data["periods"] = [time_optimize for i in data.index] - - data_new = [] - for p in time_optimize: - for _idx, row in data.iterrows(): - if p not in row["periods"]: - continue - if old_name[0:5] == "TimeS": # horrible but covers TimeSeason and TimeSegmentFraction - data_new.append((p, *row.iloc[0:-1])) - else: - data_new.append((row.iloc[0], p, *row.iloc[1:-1])) - - if old_name[0:5] == "TimeS": # horrible but covers TimeSeason and TimeSegmentFraction - cols = ["period",*cols] - else: - cols = [cols[0],"period",*cols[1::]] - - # construct the query with correct number of placeholders - num_placeholders = len(data_new[0]) - placeholders = ",".join(["?" 
for _ in range(num_placeholders)]) - query = ( - "INSERT OR REPLACE INTO " - f"{new_name}{tuple(c for c in cols) if len(cols)>1 else f'({cols[0]})'} " - f"VALUES ({placeholders})" - ) - con_new.executemany(query, data_new) - print(f"Transfered {len(data_new)} rows from {old_name} to {new_name}") - - -print("\n --- Making some final changes ---") -n_del = len(con_new.execute( - "SELECT * FROM DemandSpecificDistribution " - "WHERE (region, period, demand_name) " - "NOT IN (SELECT region, period, commodity FROM Demand)" -).fetchall()) -if n_del > 0: - con_new.execute( - "DELETE FROM DemandSpecificDistribution " - "WHERE (region, period, demand_name) " - "NOT IN (SELECT region, period, commodity FROM Demand)" - ) - print( - f"{n_del} extraneous rows removed from DemandSpecificDistribution after adding period index" - ) - -# TimeSeason unique seasons to SeasonLabel -con_new.execute("INSERT OR REPLACE INTO SeasonLabel(season) SELECT DISTINCT season FROM TimeSeason") -print("Filled SeasonLabel") - -# Removal of tech_resource -con_new.execute("UPDATE Technology SET flag='p' WHERE flag=='r';") -print("Converted all resource techs to production techs.") - -# LoanLifetimeTech -> LoanLifetimeProcess -try: - data = con_old.execute("SELECT region, tech, lifetime, notes FROM LoanLifetimeTech").fetchall() -except sqlite3.OperationalError: - print("TABLE NOT FOUND: LoanLifetimeTech") - -if not data: - print("No data for: LoanLifetimeTech") -else: - new_data = [] - for row in data: - vints = [ - v[0] - for v in con_old.execute( - f'SELECT vintage FROM Efficiency WHERE region=="{row[0]}" AND tech="{row[1]}"' - ).fetchall() - ] - for v in vints: - new_data.append((row[0], row[1], v, row[2], row[3])) - query = "INSERT OR REPLACE INTO LoanLifetimeProcess VALUES (?,?,?,?,?)" - con_new.executemany(query, new_data) - print(f"Transfered {len(new_data)} rows from LifetimeLoanTech to LifetimeLoanProcess") - - -# Warn about incompatible changes -print( - "\n --- The following transfers were 
impossible due to incompatible changes. Transfer " - "manually. ---" -) -for old_name, new_name in no_transfer.items(): - # Check if it exists in the old database. If not, no need to warn about it. - try: - con_old.execute(f"SELECT * FROM {old_name}").fetchone() - except sqlite3.OperationalError: - continue - print(f"{old_name} to {new_name}") - - -print("\n --- Updating MetaData ---") -cur.execute("DELETE FROM MetaData WHERE element == 'myopic_base_year'") -print( - "myopic_base_year removed from MetaData. This parameter is no longer used. " \ - "Costs will discount to the first future period." -) -cur.execute("UPDATE MetaData SET value = 1 WHERE element == 'DB_MINOR'") -print("Updated database version to 3.1") - - - -print("\n --- Validating foreign keys ---") -con_new.commit() -con_new.execute("VACUUM;") -con_new.execute("PRAGMA FOREIGN_KEYS=1;") -try: - data = con_new.execute("PRAGMA FOREIGN_KEY_CHECK;").fetchall() - if not data: - print("No Foreign Key Failures. (Good news!)") - else: - print("\nFK check fails (MUST BE FIXED):") - print("(Table, Row ID, Reference Table, (fkid) )") - for row in data: - print(row) -except sqlite3.OperationalError as e: - print("Foreign Key Check FAILED on new DB. Something may be wrong with schema.") - print(e) - -print("\nFinished! Check your database for any missing data." 
- " If there was a mismatch of table names, something may have been lost.") - -con_new.close() -con_old.close() diff --git a/temoa/utilities/master_migration.py b/temoa/utilities/master_migration.py index afd57016..51d35c94 100644 --- a/temoa/utilities/master_migration.py +++ b/temoa/utilities/master_migration.py @@ -1,6 +1,9 @@ import argparse +import os import re import sqlite3 + +import tempfile from pathlib import Path from typing import Any @@ -238,7 +241,7 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con vints = [ v[0] for v in con_old.execute( - f'SELECT vintage FROM Efficiency WHERE region="{row[0]}" AND tech="{row[1]}"' + 'SELECT vintage FROM Efficiency WHERE region=? AND tech=?', (row[0], row[1]) ).fetchall() ] for v in vints: @@ -373,19 +376,33 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> None: - con_old = sqlite3.connect(source_path) - con_new = sqlite3.connect(output_path) - with open(schema_path, encoding='utf-8') as f: - con_new.executescript(f.read()) + fd, temp_path_str = tempfile.mkstemp(suffix='.sqlite', prefix='temp_migration_') + os.close(fd) + temp_path = Path(temp_path_str) - con_new.execute('PRAGMA foreign_keys = 0;') - execute_v3_to_v4_migration(con_old, con_new) - - con_new.commit() - con_new.execute('VACUUM;') - con_new.execute('PRAGMA foreign_keys = 1;') + con_old = sqlite3.connect(source_path) + con_new = sqlite3.connect(temp_path) + + try: + with open(schema_path, encoding='utf-8') as f: + con_new.executescript(f.read()) + + con_new.execute('PRAGMA foreign_keys = 0;') + execute_v3_to_v4_migration(con_old, con_new) + + con_new.commit() + con_new.execute('VACUUM;') + con_new.execute('PRAGMA foreign_keys = 1;') + except Exception: + con_old.close() + con_new.close() + if temp_path.exists(): + os.remove(temp_path) + raise + con_old.close() con_new.close() + os.replace(temp_path, 
output_path) def migrate_sql_dump(source_path: Path, schema_path: Path, output_path: Path) -> None: diff --git a/temoa/utilities/run_all_v4_migrations.py b/temoa/utilities/run_all_v4_migrations.py index a29ade1e..9ffdf1c7 100644 --- a/temoa/utilities/run_all_v4_migrations.py +++ b/temoa/utilities/run_all_v4_migrations.py @@ -37,45 +37,53 @@ def run_command( def run_migrations( - input_dir: Path, migration_script: Path, schema_path: Path, dry_run: bool = False + input_dir: Path, migration_script: Path, schema_path: Path, dry_run: bool = False, silent: bool = False ) -> None: if not input_dir.is_dir(): - print(f'Error: Input directory not found at {input_dir}') - sys.exit(1) + raise FileNotFoundError(f'Error: Input directory not found at {input_dir}') if not migration_script.is_file(): - print(f'Error: Migration script not found at {migration_script}') - sys.exit(1) + raise FileNotFoundError(f'Error: Migration script not found at {migration_script}') if not schema_path.is_file(): - print(f'Error: schema file not found at {schema_path}') - sys.exit(1) + raise FileNotFoundError(f'Error: schema file not found at {schema_path}') - print(f'Scanning for .sql and .sqlite files in: {input_dir}') + if not silent: + print(f'Scanning for .sql and .sqlite files in: {input_dir}') sql_files = list(input_dir.glob('*.sql')) db_files = list(input_dir.glob('*.sqlite')) + list(input_dir.glob('*.db')) all_files = sql_files + db_files if not all_files: - print(f'No .sql, .sqlite, or .db files found in {input_dir}. Exiting.') + if not silent: + print(f'No .sql, .sqlite, or .db files found in {input_dir}. 
Exiting.') return if dry_run: - print('\n--- Dry Run ---') - print(f'The following {len(all_files)} files would be processed:') - for f in all_files: - print(f' - {f.name}') - print('\nNo files will be modified in dry run mode.') + if not silent: + print('\n--- Dry Run ---') + print(f'The following {len(all_files)} files would be processed:') + for f in all_files: + print(f' - {f.name}') + print('\nNo files will be modified in dry run mode.') return - print(f'\n--- Starting Migration of {len(all_files)} files ---') + if not silent: + print(f'\n--- Starting Migration of {len(all_files)} files ---') processed_count = 0 failed_files = [] for target_file in all_files: - print(f'\nProcessing: {target_file.name}') + if not silent: + print(f'\nProcessing: {target_file.name}') ext = target_file.suffix.lower() - temp_output_file = Path(tempfile.mkstemp(suffix=ext, prefix='temp_migrated_')[1]) - original_backup_file = Path(tempfile.mkstemp(suffix='.bak', prefix='orig_backup_')[1]) + fd1, path1 = tempfile.mkstemp(suffix=ext, prefix='temp_migrated_') + os.close(fd1) + temp_output_file = Path(path1) + + fd2, path2 = tempfile.mkstemp(suffix='.bak', prefix='orig_backup_') + os.close(fd2) + original_backup_file = Path(path2) + mig_type = 'sql' if ext == '.sql' else 'db' try: @@ -84,7 +92,7 @@ def run_migrations( # 2. Run migration script, outputting to a temporary file migration_cmd = [ - 'python3', + sys.executable, str(migration_script), '--input', str(target_file), @@ -100,16 +108,19 @@ def run_migrations( if result.returncode == 0: # 3. If successful, overwrite original file shutil.copy2(temp_output_file, target_file) - print(f'SUCCESS: {target_file.name} migrated and overwritten.') + if not silent: + print(f'SUCCESS: {target_file.name} migrated and overwritten.') processed_count += 1 else: # 4. On failure, restore original file - print(f'FAILED: Migration for {target_file.name} failed. 
Restoring original file.') + if not silent: + print(f'FAILED: Migration for {target_file.name} failed. Restoring original file.') shutil.copy2(original_backup_file, target_file) failed_files.append(target_file.name) except Exception as e: - print(f'CRITICAL ERROR processing {target_file.name}: {e}. Restoring original file.') + if not silent: + print(f'CRITICAL ERROR processing {target_file.name}: {e}. Restoring original file.') if original_backup_file.exists(): shutil.copy2(original_backup_file, target_file) failed_files.append(target_file.name) @@ -119,13 +130,14 @@ def run_migrations( if original_backup_file.exists(): os.remove(original_backup_file) - print('\n--- Migration Summary ---') - print(f'Total files processed: {processed_count}') + if not silent: + print('\n--- Migration Summary ---') + print(f'Total files processed: {processed_count}') if failed_files: - print(f'FAILED files: {", ".join(failed_files)}') - sys.exit(1) + raise RuntimeError(f'FAILED files: {", ".join(failed_files)}') else: - print('All files migrated successfully.') + if not silent: + print('All files migrated successfully.') def main() -> None: diff --git a/temoa/utilities/sql_migration_v3_1_to_v4.py b/temoa/utilities/sql_migration_v3_1_to_v4.py deleted file mode 100644 index 357e2f17..00000000 --- a/temoa/utilities/sql_migration_v3_1_to_v4.py +++ /dev/null @@ -1,391 +0,0 @@ -#!/usr/bin/env python3 -""" -sql_migration_v_3_1_to_v4.py - -Converts a v3.1 SQL dump (text) into a valid v4 SQL dump. -This script: -1. Loads the v3.1 SQL dump into a temporary in-memory SQLite database. -2. Applies the v4 schema to a new in-memory SQLite database. -3. Programmatically queries data from the old in-memory DB, maps table/column - names using the defined rules (non-cascading, case-sensitive-first, etc.), - and inserts data into the new in-memory v4 DB. -4. Uses SQLite's built-in .dump functionality to generate the final v4 SQL dump. 
- -Usage: - python sql_migration_v3_1_to_v4.py --input v3_1.sql \ - --schema temoa_schema_v4.sql \ - --output v4.sql \ - [--debug] -""" - -from __future__ import annotations - -import argparse -import re -import sqlite3 -import sys - -# ------------------ Mapping configuration (mirror sqlite migrator) ------------------ -CUSTOM_MAP: dict[str, str] = { - 'TimeSeason': 'time_season', - 'time_season': 'time_season', - 'TimeSeasonSequential': 'time_season_sequential', - 'time_season_sequential': 'time_season_sequential', - 'TimeNext': 'time_manual', - 'CommodityDStreamProcess': 'commodity_down_stream_process', - 'commodityUStreamProcess': 'commodity_up_stream_process', - 'SegFrac': 'segment_fraction', - 'segfrac': 'segment_fraction', - 'MetaDataReal': 'metadata_real', - 'MetaData': 'metadata', - 'Myopicefficiency': 'myopic_efficiency', - 'DB_MAJOR': 'db_major', - 'DB_MINOR': 'db_minor', -} -CUSTOM_EXACT_ONLY = {'time_season', 'time_season_sequential'} -CUSTOM_KEYS_SORTED = sorted( - [k for k in CUSTOM_MAP.keys() if k not in CUSTOM_EXACT_ONLY], key=lambda k: -len(k) -) - - -# ------------------ Mapping functions (non-cascading) ------------------ -def to_snake_case(s: str) -> str: - if not s: - return s - if s == s.lower() and '_' in s: - return s - x = s.replace('-', '_').replace(' ', '_') - x = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', x) - x = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', x) - x = re.sub(r'__+', '_', x) - return x.lower() - - -def map_token_no_cascade(token: str) -> str: - if not token: - return token - # prevent cascading (already a mapped output) - mapped_values = {v.lower() for v in CUSTOM_MAP.values()} - if token.lower() in mapped_values: - return token.lower() - # exact case-sensitive - if token in CUSTOM_MAP: - return CUSTOM_MAP[token].lower() - # exact case-insensitive - tl = token.lower() - for k, v in CUSTOM_MAP.items(): - if tl == k.lower(): - return v.lower() - # avoid substring replacements for PascalCase - if any(c.isupper() for c in token): - 
return to_snake_case(token) - # substring replacements (longest-first) - orig = token - orig_lower = orig.lower() - replacements = [(k, CUSTOM_MAP[k]) for k in CUSTOM_KEYS_SORTED if k.lower() in orig_lower] - if replacements: - out = [] - i = 0 - length = len(orig) - while i < length: - matched = False - for key, repl in replacements: - kl = len(key) - if i + kl <= length and orig[i : i + kl].lower() == key.lower(): - out.append(repl) - i += kl - matched = True - break - if not matched: - out.append(orig[i]) - i += 1 - mapped_once = ''.join(out) - mapped_once = re.sub(r'__+', '_', mapped_once).lower() - return mapped_once - return to_snake_case(token) - - -def map_table_name(table: str) -> str: - return map_token_no_cascade(table) - - -def map_column_name(col: str) -> str: - mapped = map_token_no_cascade(col) - if mapped == 'seg_frac': # Ensure canonical form for this column - mapped = 'segment_fraction' - return mapped - - -def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple]: - try: - return conn.execute(f'PRAGMA table_info({table});').fetchall() - except sqlite3.OperationalError: - return [] - - -def migrate_dump_to_sqlite(args) -> None: - # --- 1. Load v3.1 SQL dump into a temporary in-memory DB --- - print(f'Loading v3.1 SQL dump from {args.input} into in-memory DB...') - con_old_in_memory = sqlite3.connect(':memory:') - try: - with open(args.input, encoding='utf-8') as f: - v3_1_sql_dump = f.read() - con_old_in_memory.executescript(v3_1_sql_dump) - print('V3.1 dump loaded.') - except Exception as e: - print(f'ERROR: Failed to load v3.1 dump: {e}') - sys.exit(1) - - # --- 2. 
Create new in-memory DB and apply v4 schema --- - print(f'Applying v4 schema from {args.schema} to new in-memory DB...') - con_new_in_memory = sqlite3.connect(':memory:') - try: - with open(args.schema, encoding='utf-8') as f: - v4_schema_sql = f.read() - con_new_in_memory.executescript(v4_schema_sql) - con_new_in_memory.execute('PRAGMA foreign_keys = 0;') # Temporarily disable for migration - print('V4 schema applied.') - except Exception as e: - print(f'ERROR: Failed to apply v4 schema: {e}') - sys.exit(1) - - # Get old and new table/column info - old_tables = [ - r[0] - for r in con_old_in_memory.execute( - "SELECT name FROM sqlite_master WHERE type='table'" - ).fetchall() - if not r[0].lower().startswith('sqlite_') - ] - new_db_tables = [ - r[0] - for r in con_new_in_memory.execute( - "SELECT name FROM sqlite_master WHERE type='table'" - ).fetchall() - if not r[0].lower().startswith('sqlite_') - ] - - # --- 3. Programmatically copy data --- - total_rows_copied = 0 - for old_table_name in old_tables: - if old_table_name in ( - 'MetaData', - 'MetaDataReal', - 'TimeSeason', - 'time_season', - 'time_of_day', - 'time_season_sequential', - 'TimeSeasonSequential', - ): - if args.debug: - print(f'DEBUG: Skipping {old_table_name} (handled by custom logic)') - continue - - if old_table_name == 'CapacityFactorProcess': - # Special case: aggregate across periods - old_data = con_old_in_memory.execute( - 'SELECT region, season, tod, tech, vintage, AVG(factor) ' - 'FROM CapacityFactorProcess ' - 'GROUP BY region, season, tod, tech, vintage' - ).fetchall() - if old_data: - con_new_in_memory.executemany( - 'INSERT OR REPLACE INTO capacity_factor_process ' - '(region, season, tod, tech, vintage, factor) ' - 'VALUES (?, ?, ?, ?, ?, ?)', - old_data, - ) - print( - f'Aggregated {len(old_data)} rows: {old_table_name} -> capacity_factor_process' - ) - total_rows_copied += len(old_data) - continue - - mapped_new_table_name = map_table_name(old_table_name) - - if mapped_new_table_name 
not in new_db_tables: - # Tolerant fallback: if canonical target table is missing, try candidates - candidates = [ - t - for t in new_db_tables - if t.startswith(mapped_new_table_name) - or mapped_new_table_name in t - or mapped_new_table_name.replace('_', '') in t.replace('_', '') - ] - if len(candidates) == 1: - chosen_table = candidates[0] - print( - f'NOTE: Mapped target {mapped_new_table_name} not found for {old_table_name}; ' - f'using candidate {chosen_table}' - ) - mapped_new_table_name = chosen_table - else: - print( - f'SKIP: No target table for {old_table_name} -> {mapped_new_table_name} ' - f'(candidates: {candidates})' - ) - continue - - old_cols_info = get_table_info(con_old_in_memory, old_table_name) - new_cols_info = get_table_info(con_new_in_memory, mapped_new_table_name) - - if not old_cols_info: - if args.debug: - print(f'DEBUG: No column info for old table {old_table_name}') - continue - if not new_cols_info: - if args.debug: - print(f'DEBUG: No column info for new table {mapped_new_table_name}') - continue - - old_actual_cols = [c[1] for c in old_cols_info] - new_target_cols = [c[1] for c in new_cols_info] - - selectable_old_cols_for_query = [] # actual column names in old table to select - insert_target_cols_for_query = [] # mapped column names for new table's INSERT clause - - for oc in old_actual_cols: - mapped_oc = map_column_name(oc) - if mapped_oc in new_target_cols: - selectable_old_cols_for_query.append(oc) - insert_target_cols_for_query.append(mapped_oc) - - if not selectable_old_cols_for_query: - if args.debug: - print( - f'DEBUG: No common/mappable columns from {old_table_name} to ' - f'{mapped_new_table_name}. Skipping data copy.' - ) - continue - - select_query = f'SELECT {",".join(selectable_old_cols_for_query)} FROM {old_table_name}' - rows_from_old_table = con_old_in_memory.execute(select_query).fetchall() - - if not rows_from_old_table: - if args.debug: - print(f'DEBUG: No data in {old_table_name}. 
Skipping.') - continue - - # Filter out rows that are entirely NULL - filtered_rows_for_insert = [r for r in rows_from_old_table if any(v is not None for v in r)] - if not filtered_rows_for_insert: - if args.debug: - print(f'DEBUG: All rows from {old_table_name} were NULL. Skipping.') - continue - - placeholders = ','.join(['?'] * len(insert_target_cols_for_query)) - insert_query = ( - f'INSERT OR REPLACE INTO {mapped_new_table_name} ' - f'({",".join(insert_target_cols_for_query)}) VALUES ({placeholders})' - ) - - con_new_in_memory.executemany(insert_query, filtered_rows_for_insert) - rows_copied_this_table = len(filtered_rows_for_insert) - print(f'Copied {rows_copied_this_table} rows: {old_table_name} -> {mapped_new_table_name}') - total_rows_copied += rows_copied_this_table - - # --- 3b. Custom logic for restructured tables (Seasons/TOD) --- - print('Processing custom migration logic for seasons and time-of-day...') - - # 1. time_season (aggregate from TimeSegmentFraction) - try: - # v3.1: TimeSegmentFraction(period, season, tod, segfrac) - # v4: time_season(season, segment_fraction) - old_data = con_old_in_memory.execute( - 'SELECT season, SUM(segfrac) / COUNT(DISTINCT period) ' - 'FROM TimeSegmentFraction GROUP BY season' - ).fetchall() - if old_data: - con_new_in_memory.executemany( - 'INSERT OR REPLACE INTO time_season (season, segment_fraction) VALUES (?, ?)', - old_data, - ) - print(f'Propagated {len(old_data)} seasons to time_season.') - except sqlite3.OperationalError: - print('WARNING: Could not migrate seasons from TimeSegmentFraction (table missing?)') - - # 2. 
time_of_day (aggregate from TimeSegmentFraction) - try: - # v3.1: TimeSegmentFraction(period, season, tod, segfrac) -> - # tod_weights = SUM(segfrac) over season - # v4: time_of_day(tod, hours) - old_data = con_old_in_memory.execute( - 'SELECT tod, SUM(segfrac) FROM TimeSegmentFraction GROUP BY tod' - ).fetchall() - if old_data: - num_periods = ( - con_old_in_memory.execute( - 'SELECT COUNT(DISTINCT period) FROM TimeSegmentFraction' - ).fetchone()[0] - or 1 - ) - # Normalize to 24 hours - normalized_data = [(r[0], (r[1] / num_periods) * 24.0) for r in old_data] - con_new_in_memory.executemany( - 'INSERT OR REPLACE INTO time_of_day (tod, hours) VALUES (?, ?)', - normalized_data, - ) - print(f'Propagated {len(normalized_data)} time-of-day slots to time_of_day.') - except sqlite3.OperationalError: - print('WARNING: Could not migrate time_of_day from TimeSegmentFraction.') - - # 3. time_season_sequential (aggregate from TimeSeasonSequential and TimeSegmentFraction) - # This is tricky because v3.1 had period-dependent sequential seasons. - # We take the first available period's definition. 
- try: - # v3.1: TimeSeasonSequential(period, sequence, seas_seq, season, num_days) - # v4: time_season_sequential(seas_seq, season, segment_fraction) - first_period = con_old_in_memory.execute( - 'SELECT MIN(period) FROM TimeSeasonSequential' - ).fetchone()[0] - if first_period: - old_data = con_old_in_memory.execute( - 'SELECT tss.seas_seq, tss.season, (tss.num_days / 365.25) ' - 'FROM TimeSeasonSequential tss ' - 'WHERE tss.period = ?', - (first_period,), - ).fetchall() - if old_data: - con_new_in_memory.executemany( - 'INSERT OR REPLACE INTO time_season_sequential ' - '(seas_seq, season, segment_fraction) VALUES (?, ?, ?)', - old_data, - ) - print(f'Propagated {len(old_data)} sequential seasons to time_season.') - except sqlite3.OperationalError: - pass # Optional table - - # --- Final updates and dump --- - con_new_in_memory.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MAJOR', 4, '')") - con_new_in_memory.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MINOR', 0, '')") - con_new_in_memory.commit() - con_new_in_memory.execute('PRAGMA foreign_keys = 1;') # Re-enable FKs - - # Generate the v4 SQL dump - print(f'Generating v4 SQL dump to {args.output}...') - with open(args.output, 'w', encoding='utf-8') as f_out: - for line in con_new_in_memory.iterdump(): - # Add back "PRAGMA foreign_keys=OFF;" and "BEGIN TRANSACTION;" at the start if missing - # And "COMMIT;" at the end. - # It seems .iterdump() already adds PRAGMA and BEGIN/COMMIT. - f_out.write(line + '\n') - - con_old_in_memory.close() - con_new_in_memory.close() - print( - f'Conversion complete. Total rows copied: {total_rows_copied}. 
Output dump: {args.output}' - ) - - -def parse_cli() -> argparse.Namespace: - p = argparse.ArgumentParser() - p.add_argument('--input', '-i', required=True, help='Path to v3.1 SQL dump file') - p.add_argument('--schema', '-s', required=True, help='Path to v4 schema SQL file') - p.add_argument('--output', '-o', required=True, help='Path for output v4 SQL dump file') - p.add_argument('--debug', action='store_true', help='Enable debug output') - return p.parse_args() - - -if __name__ == '__main__': - args = parse_cli() - migrate_dump_to_sqlite(args) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7bb8468f..00184115 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -176,6 +176,11 @@ def test_cli_migrate_accepts_directory_input(tmp_path: Path) -> None: """Test that the migrate command accepts a directory as input for batch processing.""" dummy_dir = tmp_path / 'my_dummy_dir' dummy_dir.mkdir() + + src_file = UTOPIA_SQL_FIXTURE + input_file = dummy_dir / src_file.name + shutil.copy2(src_file, input_file) + args = ['migrate', str(dummy_dir)] result = runner.invoke(app, args) @@ -185,7 +190,7 @@ def test_cli_migrate_accepts_directory_input(tmp_path: Path) -> None: assert 'Batch migrating directory' in normalized_output # Check for the directory name in the original output (paths may be split across lines) assert 'my_dummy_dir' in result.stdout - assert 'No .sql, .sqlite, or .db files found' in result.stdout + assert 'Total files processed: 1' in normalized_output def test_cli_migrate_sql_file_auto_output_writable_input_dir(tmp_path: Path) -> None: diff --git a/tests/test_v4_migration.py b/tests/test_v4_migration.py index 534ee211..bc0820b8 100644 --- a/tests/test_v4_migration.py +++ b/tests/test_v4_migration.py @@ -8,20 +8,30 @@ # Constants REPO_ROOT = Path(__file__).parents[1] UTILITIES_DIR = REPO_ROOT / 'temoa' / 'utilities' -SCHEMA_V3_1 = REPO_ROOT / 'temoa' / 'db_schema' / 'temoa_schema_v3_1.sql' SCHEMA_V4 = REPO_ROOT / 'temoa' / 'db_schema' / 
'temoa_schema_v4.sql' -MOCK_DATA = REPO_ROOT / 'tests' / 'testing_data' / 'migration_v3_1_mock.sql' - -def test_v4_migrations(tmp_path: Path) -> None: +SCHEMA_V3 = REPO_ROOT / 'temoa' / 'db_schema' / 'temoa_schema_v3.sql' +SCHEMA_V3_1 = REPO_ROOT / 'temoa' / 'db_schema' / 'temoa_schema_v3_1.sql' +MOCK_DATA_V3 = REPO_ROOT / 'tests' / 'testing_data' / 'migration_v3_mock.sql' +MOCK_DATA_V3_1 = REPO_ROOT / 'tests' / 'testing_data' / 'migration_v3_1_mock.sql' + + +@pytest.mark.parametrize( + 'schema_file, mock_data_file', + [ + (SCHEMA_V3, MOCK_DATA_V3), + (SCHEMA_V3_1, MOCK_DATA_V3_1), + ] +) +def test_v4_migrations(tmp_path: Path, schema_file: Path, mock_data_file: Path) -> None: """Test both SQL and SQLite master migrators.""" - # 1. Create v3.1 SQLite DB (acts as our v3 mock since the mock schema is 3.1) - db_v3_1 = tmp_path / 'test_v3_1.sqlite' + # 1. Create SQLite DB + db_v3_1 = tmp_path / 'test_db.sqlite' with sqlite3.connect(db_v3_1) as conn: conn.execute('PRAGMA foreign_keys = OFF') - conn.executescript(SCHEMA_V3_1.read_text()) - conn.executescript(MOCK_DATA.read_text()) + conn.executescript(schema_file.read_text()) + conn.executescript(mock_data_file.read_text()) conn.execute('PRAGMA foreign_keys = ON') # 2. 
Dump to SQL diff --git a/tests/testing_data/migration_v3_mock.sql b/tests/testing_data/migration_v3_mock.sql new file mode 100644 index 00000000..55a97551 --- /dev/null +++ b/tests/testing_data/migration_v3_mock.sql @@ -0,0 +1,30 @@ +-- Mock data for v3 -> v4 migration testing +INSERT INTO Region (region) VALUES ('R1'); +INSERT OR IGNORE INTO TechnologyType (label, description) VALUES ('p', 'production'); +INSERT INTO Technology (tech, flag, unlim_cap, annual, reserve, curtail, retire, flex, exchange) VALUES ('T1', 'p', 0, 0, 0, 0, 0, 0, 0); + +INSERT INTO TimePeriodType (label, description) VALUES ('e', 'existing'); +INSERT INTO TimePeriodType (label, description) VALUES ('f', 'future'); +INSERT INTO TimePeriod (sequence, period, flag) VALUES (1, 2020, 'e'); +INSERT INTO TimePeriod (sequence, period, flag) VALUES (2, 2030, 'f'); +INSERT INTO TimePeriod (sequence, period, flag) VALUES (3, 2040, 'f'); + + + +INSERT INTO TimeOfDay (sequence, tod) VALUES (1, 'day'); +INSERT INTO TimeOfDay (sequence, tod) VALUES (2, 'night'); + +INSERT INTO TimeSeason (sequence, season) VALUES (1, 'winter'); +INSERT INTO TimeSeason (sequence, season) VALUES (2, 'summer'); + +INSERT INTO TimeSegmentFraction (season, tod, segfrac) VALUES ('winter', 'day', 0.2); +INSERT INTO TimeSegmentFraction (season, tod, segfrac) VALUES ('winter', 'night', 0.1); +INSERT INTO TimeSegmentFraction (season, tod, segfrac) VALUES ('summer', 'day', 0.4); +INSERT INTO TimeSegmentFraction (season, tod, segfrac) VALUES ('summer', 'night', 0.3); + +INSERT OR IGNORE INTO CommodityType (label, description) VALUES ('p', 'physical'); +INSERT INTO Commodity (name, flag) VALUES ('In', 'p'); +INSERT INTO Commodity (name, flag) VALUES ('Out', 'p'); + +INSERT INTO CapacityFactorProcess (region, season, tod, tech, vintage, factor) VALUES ('R1', 'winter', 'day', 'T1', 2030, 0.6); +INSERT INTO Efficiency (region, input_comm, tech, vintage, output_comm, efficiency) VALUES ('R1', 'In', 'T1', 2030, 'Out', 0.9); From 
05f22cdddb24baa59968bad5dab54a1efdb1b7eb Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 25 Mar 2026 13:08:30 -0400 Subject: [PATCH 4/5] Harden migration refactor: modular helpers, atomic SQL exports, and expanded test coverage --- temoa/utilities/master_migration.py | 108 ++++++++++++++++------- temoa/utilities/run_all_v4_migrations.py | 10 +-- tests/test_v4_migration.py | 2 +- tests/testing_data/migration_v3_mock.sql | 2 + 4 files changed, 82 insertions(+), 40 deletions(-) diff --git a/temoa/utilities/master_migration.py b/temoa/utilities/master_migration.py index 51d35c94..43418eb4 100644 --- a/temoa/utilities/master_migration.py +++ b/temoa/utilities/master_migration.py @@ -108,9 +108,7 @@ def map_token_no_cascade(token: str) -> str: if not matched: out.append(orig[i]) i += 1 - mapped_once = ''.join(out) - mapped_once = re.sub(r'__+', '_', mapped_once).lower() - return mapped_once + return re.sub(r'__+', '_', ''.join(out)).lower() return to_snake_case(token) @@ -121,19 +119,10 @@ def get_table_info(conn: sqlite3.Connection, table: str) -> list[tuple[Any, ...] return [] -def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> None: - old_tables = [ - r[0] - for r in con_old.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() - ] - new_tables = [ - r[0] - for r in con_new.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() - ] - total = 0 - - # 1. 
Handle operator-added tables +def _migrate_operator_tables(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> int: + """Migrate max/min tables to operator constraints.""" print('--- Migrating max/min tables to operator constraints ---') + total = 0 for old_name, (new_name, operator) in OPERATOR_ADDED_TABLES.items(): try: data = con_old.execute(f'SELECT * FROM {old_name}').fetchall() @@ -147,6 +136,8 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con continue op_index = new_cols.index('operator') + assert 0 <= op_index < len(new_cols), f'Operator column missing or invalid for {new_name}' + data = [(*row[0:op_index], operator, *row[op_index : len(new_cols) - 1]) for row in data] # Move period to vintage if applicable @@ -170,8 +161,21 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con con_new.executemany(query, data) print(f'Migrated {len(data)} rows: {old_name} -> {new_name}') total += len(data) + return total + - # 2. 
Standard directory / copied tables +def _migrate_standard_tables(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> int: + """Migrate standard tables by mapping names and columns.""" + print('--- Executing standard table migrations ---') + total = 0 + old_tables = [ + r[0] + for r in con_old.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + ] + new_tables = [ + r[0] + for r in con_new.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() + ] custom_handled_old_tables = { 'MetaData', 'MetaDataReal', @@ -185,7 +189,6 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con 'CapacityFactorProcess', }.union(OPERATOR_ADDED_TABLES.keys()) - print('--- Executing standard table migrations ---') for old in old_tables: if old.lower().startswith('sqlite_') or old in custom_handled_old_tables: continue @@ -226,11 +229,12 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con con_new.executemany(q, filtered) print(f'Copied {len(filtered)} rows: {old} -> {new}') total += len(filtered) + return total - # 3. Custom specific logics - print('--- Processing custom migration logic ---') - # 3.1 LoanLifetimeTech -> loan_lifetime_process +def _migrate_loan_lifetime(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> int: + """Migrate LoanLifetimeTech to loan_lifetime_process.""" + total = 0 try: data = con_old.execute( 'SELECT region, tech, lifetime, notes FROM LoanLifetimeTech' @@ -254,8 +258,13 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con total += len(new_data) except sqlite3.OperationalError: pass + return total - # 3.2 time_season (aggregate from TimeSegmentFraction) + +def _migrate_time_tables(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> int: + """Migrate time-related tables (season, tod, sequential).""" + total = 0 + # 1. 
time_season try: old_data = [] cols = [c[1] for c in get_table_info(con_old, 'TimeSegmentFraction')] @@ -278,7 +287,7 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con except sqlite3.OperationalError: pass - # 3.3 time_of_day (aggregate from TimeSegmentFraction) + # 2. time_of_day try: old_data = [] cols = [c[1] for c in get_table_info(con_old, 'TimeSegmentFraction')] @@ -313,7 +322,7 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con except sqlite3.OperationalError: pass - # 3.4 time_season_sequential + # 3. time_season_sequential try: old_data = [] cols = [c[1] for c in get_table_info(con_old, 'TimeSeasonSequential')] @@ -341,8 +350,12 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con total += len(old_data) except sqlite3.OperationalError: pass + return total + - # 3.5 CapacityFactorProcess +def _migrate_capacity_factor(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> int: + """Migrate CapacityFactorProcess.""" + total = 0 try: old_data = [] cols = [c[1] for c in get_table_info(con_old, 'CapacityFactorProcess')] @@ -367,8 +380,21 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con total += len(old_data) except sqlite3.OperationalError: pass + return total + + +def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Connection) -> None: + """Main migration logic router.""" + total = 0 + total += _migrate_operator_tables(con_old, con_new) + total += _migrate_standard_tables(con_old, con_new) + + print('--- Processing custom migration logic ---') + total += _migrate_loan_lifetime(con_old, con_new) + total += _migrate_time_tables(con_old, con_new) + total += _migrate_capacity_factor(con_old, con_new) - # 4. 
Final Updates + # Final Updates con_new.execute("UPDATE technology SET flag='p' WHERE flag='r';") con_new.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MAJOR', 4, '')") con_new.execute("INSERT OR REPLACE INTO metadata VALUES ('DB_MINOR', 0, '')") @@ -382,14 +408,14 @@ def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> con_old = sqlite3.connect(source_path) con_new = sqlite3.connect(temp_path) - + try: with open(schema_path, encoding='utf-8') as f: con_new.executescript(f.read()) - + con_new.execute('PRAGMA foreign_keys = 0;') execute_v3_to_v4_migration(con_old, con_new) - + con_new.commit() con_new.execute('VACUUM;') con_new.execute('PRAGMA foreign_keys = 1;') @@ -399,7 +425,7 @@ def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> if temp_path.exists(): os.remove(temp_path) raise - + con_old.close() con_new.close() os.replace(temp_path, output_path) @@ -420,12 +446,26 @@ def migrate_sql_dump(source_path: Path, schema_path: Path, output_path: Path) -> con_new_in_memory.commit() con_new_in_memory.execute('PRAGMA foreign_keys = 1;') - with open(output_path, 'w', encoding='utf-8') as f_out: - for line in con_new_in_memory.iterdump(): - f_out.write(line + '\n') + fd, temp_path_str = tempfile.mkstemp( + suffix='.sql', prefix='temp_sql_export_', dir=output_path.parent + ) + temp_path = Path(temp_path_str) + + try: + with os.fdopen(fd, 'w', encoding='utf-8') as f_out: + for line in con_new_in_memory.iterdump(): + f_out.write(line + '\n') + f_out.flush() + os.fsync(f_out.fileno()) - con_old_in_memory.close() - con_new_in_memory.close() + os.replace(temp_path, output_path) + except Exception: + if temp_path.exists(): + os.remove(temp_path) + raise + finally: + con_old_in_memory.close() + con_new_in_memory.close() if __name__ == '__main__': diff --git a/temoa/utilities/run_all_v4_migrations.py b/temoa/utilities/run_all_v4_migrations.py index 9ffdf1c7..1109f3fb 100644 --- 
a/temoa/utilities/run_all_v4_migrations.py +++ b/temoa/utilities/run_all_v4_migrations.py @@ -79,11 +79,11 @@ def run_migrations( fd1, path1 = tempfile.mkstemp(suffix=ext, prefix='temp_migrated_') os.close(fd1) temp_output_file = Path(path1) - + fd2, path2 = tempfile.mkstemp(suffix='.bak', prefix='orig_backup_') os.close(fd2) original_backup_file = Path(path2) - + mig_type = 'sql' if ext == '.sql' else 'db' try: @@ -135,9 +135,9 @@ def run_migrations( print(f'Total files processed: {processed_count}') if failed_files: raise RuntimeError(f'FAILED files: {", ".join(failed_files)}') - else: - if not silent: - print('All files migrated successfully.') + + if not silent: + print('All files migrated successfully.') def main() -> None: diff --git a/tests/test_v4_migration.py b/tests/test_v4_migration.py index bc0820b8..a4519077 100644 --- a/tests/test_v4_migration.py +++ b/tests/test_v4_migration.py @@ -17,7 +17,7 @@ @pytest.mark.parametrize( - 'schema_file, mock_data_file', + ('schema_file', 'mock_data_file'), [ (SCHEMA_V3, MOCK_DATA_V3), (SCHEMA_V3_1, MOCK_DATA_V3_1), diff --git a/tests/testing_data/migration_v3_mock.sql b/tests/testing_data/migration_v3_mock.sql index 55a97551..28973eb2 100644 --- a/tests/testing_data/migration_v3_mock.sql +++ b/tests/testing_data/migration_v3_mock.sql @@ -28,3 +28,5 @@ INSERT INTO Commodity (name, flag) VALUES ('Out', 'p'); INSERT INTO CapacityFactorProcess (region, season, tod, tech, vintage, factor) VALUES ('R1', 'winter', 'day', 'T1', 2030, 0.6); INSERT INTO Efficiency (region, input_comm, tech, vintage, output_comm, efficiency) VALUES ('R1', 'In', 'T1', 2030, 'Out', 0.9); +INSERT INTO MinCapacity (region, tech, period, min_cap, units, notes) VALUES ('R1', 'T1', 2030, 10.0, 'GW', 'test op'); +INSERT INTO EmissionLimit (region, period, emis_comm, value, units, notes) VALUES ('R1', 2030, 'Out', 100.0, 'kt', 'test op'); From a20dc7cb53fa060137e2441b49965e32f96865d9 Mon Sep 17 00:00:00 2001 From: ParticularlyPythonicBS Date: Wed, 25 
Mar 2026 14:06:23 -0400 Subject: [PATCH 5/5] Final hardening: cross-fs safety, input validation, and operator-table verification --- temoa/utilities/master_migration.py | 25 ++++++++++++++++------ temoa/utilities/run_all_v4_migrations.py | 9 ++++---- tests/test_v4_migration.py | 10 +++++++++ tests/testing_data/migration_v3_1_mock.sql | 2 ++ 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/temoa/utilities/master_migration.py b/temoa/utilities/master_migration.py index 43418eb4..154eb392 100644 --- a/temoa/utilities/master_migration.py +++ b/temoa/utilities/master_migration.py @@ -402,7 +402,14 @@ def execute_v3_to_v4_migration(con_old: sqlite3.Connection, con_new: sqlite3.Con def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> None: - fd, temp_path_str = tempfile.mkstemp(suffix='.sqlite', prefix='temp_migration_') + if not source_path.is_file(): + raise FileNotFoundError(f"Input database not found: {source_path}") + if not schema_path.is_file(): + raise FileNotFoundError(f"Schema file not found: {schema_path}") + + fd, temp_path_str = tempfile.mkstemp( + suffix='.sqlite', prefix='temp_migration_', dir=output_path.parent + ) os.close(fd) temp_path = Path(temp_path_str) @@ -419,19 +426,25 @@ def migrate_database(source_path: Path, schema_path: Path, output_path: Path) -> con_new.commit() con_new.execute('VACUUM;') con_new.execute('PRAGMA foreign_keys = 1;') - except Exception: + con_old.close() con_new.close() + os.replace(temp_path, output_path) + except Exception: if temp_path.exists(): os.remove(temp_path) raise - - con_old.close() - con_new.close() - os.replace(temp_path, output_path) + finally: + con_old.close() + con_new.close() def migrate_sql_dump(source_path: Path, schema_path: Path, output_path: Path) -> None: + if not source_path.is_file(): + raise FileNotFoundError(f"Input SQL dump not found: {source_path}") + if not schema_path.is_file(): + raise FileNotFoundError(f"Schema file not found: {schema_path}") + 
con_old_in_memory = sqlite3.connect(':memory:') with open(source_path, encoding='utf-8') as f: con_old_in_memory.executescript(f.read()) diff --git a/temoa/utilities/run_all_v4_migrations.py b/temoa/utilities/run_all_v4_migrations.py index 1109f3fb..957bd113 100644 --- a/temoa/utilities/run_all_v4_migrations.py +++ b/temoa/utilities/run_all_v4_migrations.py @@ -24,12 +24,13 @@ def run_command( - cmd: list[str], cwd: Path | None = None, capture_output: bool = True + cmd: list[str], cwd: Path | None = None, capture_output: bool = True, silent: bool = False ) -> subprocess.CompletedProcess[str]: """Helper to run shell commands.""" - print(f'Executing: {" ".join(cmd)}') + if not silent: + print(f'Executing: {" ".join(cmd)}') result = subprocess.run(cmd, cwd=cwd, capture_output=capture_output, text=True, check=False) - if result.returncode != 0 and capture_output: + if not silent and result.returncode != 0 and capture_output: print(f'COMMAND FAILED (exit code {result.returncode}):') print('STDOUT:\n', result.stdout) print('STDERR:\n', result.stderr) @@ -103,7 +104,7 @@ def run_migrations( '--type', mig_type, ] - result = run_command(migration_cmd, cwd=Path.cwd()) + result = run_command(migration_cmd, cwd=Path.cwd(), silent=silent) if result.returncode == 0: # 3. 
If successful, overwrite original file diff --git a/tests/test_v4_migration.py b/tests/test_v4_migration.py index a4519077..42e64963 100644 --- a/tests/test_v4_migration.py +++ b/tests/test_v4_migration.py @@ -114,6 +114,16 @@ def _verify_migrated_data(conn: sqlite3.Connection) -> None: assert len(rows) == 1 assert rows[0] == ('R1', 'winter', 'day', 'T1', 2030, pytest.approx(0.6)) + # Check operator-added tables (limit_capacity and limit_emission) + cap_rows = conn.execute('SELECT region, tech_or_group, period, capacity, operator FROM limit_capacity').fetchall() + assert len(cap_rows) == 1 + assert cap_rows[0] == ('R1', 'T1', 2030, 10.0, 'ge') + + emis_rows = conn.execute('SELECT region, period, value, operator FROM limit_emission').fetchall() + assert len(emis_rows) == 1 + assert emis_rows[0] == ('R1', 2030, 100.0, 'le') + # Check metadata version major = conn.execute("SELECT value FROM metadata WHERE element='DB_MAJOR'").fetchone()[0] assert int(major) == 4 + assert int(conn.execute("SELECT value FROM metadata WHERE element='DB_MINOR'").fetchone()[0]) == 0 diff --git a/tests/testing_data/migration_v3_1_mock.sql b/tests/testing_data/migration_v3_1_mock.sql index 8e97007a..9a282962 100644 --- a/tests/testing_data/migration_v3_1_mock.sql +++ b/tests/testing_data/migration_v3_1_mock.sql @@ -28,3 +28,5 @@ INSERT INTO Commodity (name, flag) VALUES ('Out', 'p'); INSERT INTO CapacityFactorProcess (region, period, season, tod, tech, vintage, factor) VALUES ('R1', 2030, 'winter', 'day', 'T1', 2030, 0.5); INSERT INTO CapacityFactorProcess (region, period, season, tod, tech, vintage, factor) VALUES ('R1', 2040, 'winter', 'day', 'T1', 2030, 0.7); INSERT INTO Efficiency (region, input_comm, tech, vintage, output_comm, efficiency) VALUES ('R1', 'In', 'T1', 2030, 'Out', 0.9); +INSERT INTO LimitCapacity (region, period, tech_or_group, operator, capacity, units, notes) VALUES ('R1', 2030, 'T1', 'ge', 10.0, 'GW', 'test op'); +INSERT INTO LimitEmission (region, period, emis_comm, 
operator, value, units, notes) VALUES ('R1', 2030, 'Out', 'le', 100.0, 'kt', 'test op');