88from collections import defaultdict
99from typing import List , Tuple , Iterator , Optional
1010import logging
11- from concurrent .futures import ThreadPoolExecutor
11+ from concurrent .futures import ThreadPoolExecutor , as_completed
1212
1313from runtype import dataclass
1414
@@ -315,17 +315,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
315315 ('-', columns) for items in table2 but not in table1
316316 Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
317317 """
318+ # Validate options
318319 if self .bisection_factor >= self .bisection_threshold :
319320 raise ValueError ("Incorrect param values (bisection factor must be lower than threshold)" )
320321 if self .bisection_factor < 2 :
321322 raise ValueError ("Must have at least two segments per iteration (i.e. bisection_factor >= 2)" )
322323
324+ # Query and validate schema
323325 table1 , table2 = self ._threaded_call ("with_schema" , [table1 , table2 ])
324326 self ._validate_and_adjust_columns (table1 , table2 )
325327
326- key_ranges = self ._threaded_call ("query_key_range" , [table1 , table2 ])
327- mins , maxs = zip (* key_ranges )
328-
329328 key_type = table1 ._schema [table1 .key_column ]
330329 key_type2 = table2 ._schema [table2 .key_column ]
331330 if not isinstance (key_type , IKey ):
@@ -334,23 +333,42 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
334333 raise NotImplementedError (f"Cannot use column of type { key_type2 } as a key" )
335334 assert key_type .python_type is key_type2 .python_type
336335
337- # We add 1 because our ranges are exclusive of the end (like in Python)
338- try :
339- min_key = min (map (key_type .python_type , mins ))
340- max_key = max (map (key_type .python_type , maxs )) + 1
341- except (TypeError , ValueError ) as e :
342- raise type (e )(f"Cannot apply { key_type } to { mins } , { maxs } ." ) from e
336+ # Query min/max values
337+ key_ranges = self ._threaded_call_as_completed ("query_key_range" , [table1 , table2 ])
343338
344- table1 = table1 .new (min_key = min_key , max_key = max_key )
345- table2 = table2 .new (min_key = min_key , max_key = max_key )
339+ # Start with the first completed value, so we don't waste time waiting
340+ min_key1 , max_key1 = self ._parse_key_range_result (key_type , next (key_ranges ))
341+
342+ table1 , table2 = [t .new (min_key = min_key1 , max_key = max_key1 ) for t in (table1 , table2 )]
346343
347344 logger .info (
348345 f"Diffing tables | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } . "
349346 f"key-range: { table1 .min_key } ..{ table2 .max_key } , "
350347 f"size: { table2 .max_key - table1 .min_key } "
351348 )
352349
353- return self ._bisect_and_diff_tables (table1 , table2 )
350+ # Bisect (split) the table into segments, and diff them recursively.
351+ yield from self ._bisect_and_diff_tables (table1 , table2 )
352+
353+ # Now we check for the second min-max, to diff the portions we "missed".
354+ min_key2 , max_key2 = self ._parse_key_range_result (key_type , next (key_ranges ))
355+
356+ if min_key2 < min_key1 :
357+ pre_tables = [t .new (min_key = min_key2 , max_key = min_key1 ) for t in (table1 , table2 )]
358+ yield from self ._bisect_and_diff_tables (* pre_tables )
359+
360+ if max_key2 > max_key1 :
361+ post_tables = [t .new (min_key = max_key1 , max_key = max_key2 ) for t in (table1 , table2 )]
362+ yield from self ._bisect_and_diff_tables (* post_tables )
363+
364+ def _parse_key_range_result (self , key_type , key_range ):
365+ mn , mx = key_range
366+ cls = key_type .python_type
367+ # We add 1 because our ranges are exclusive of the end (like in Python)
368+ try :
369+ return cls (mn ), cls (mx ) + 1
370+ except (TypeError , ValueError ) as e :
371+ raise type (e )(f"Cannot apply { key_type } to { mn } , { mx } ." ) from e
354372
355373 def _validate_and_adjust_columns (self , table1 , table2 ):
356374 for c in table1 ._relevant_columns :
@@ -474,12 +492,26 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
474492 if checksum1 != checksum2 :
475493 yield from self ._bisect_and_diff_tables (table1 , table2 , level = level , max_rows = max (count1 , count2 ))
476494
477- def _thread_map (self , func , iter ):
495+ def _thread_map (self , func , iterable ):
496+ if not self .threaded :
497+ return map (func , iterable )
498+
499+ with ThreadPoolExecutor (max_workers = self .max_threadpool_size ) as task_pool :
500+ return task_pool .map (func , iterable )
501+
502+ def _threaded_call (self , func , iterable ):
503+ "Calls a method for each object in iterable."
504+ return list (self ._thread_map (methodcaller (func ), iterable ))
505+
506+ def _thread_as_completed (self , func , iterable ):
478507 if not self .threaded :
479- return map (func , iter )
508+ return map (func , iterable )
480509
481- task_pool = ThreadPoolExecutor (max_workers = self .max_threadpool_size )
482- return task_pool .map (func , iter )
510+ with ThreadPoolExecutor (max_workers = self .max_threadpool_size ) as task_pool :
511+ futures = [task_pool .submit (func , item ) for item in iterable ]
512+ for future in as_completed (futures ):
513+ yield future .result ()
483514
484- def _threaded_call (self , func , iter ):
485- return list (self ._thread_map (methodcaller (func ), iter ))
515+ def _threaded_call_as_completed (self , func , iterable ):
516+ "Calls a method for each object in iterable. Returned in order of completion."
517+ return self ._thread_as_completed (methodcaller (func ), iterable )
0 commit comments