11from collections import Counter
22from logging import getLogger
3+ from typing import Optional , Tuple
34import pandas
45import numpy
56from .dataframe_helpers import dataframe_shuffle
@@ -447,14 +448,15 @@ def double_merge(d):
447448
448449
449450def train_test_apart_stratify (
450- df ,
451+ df : pandas . DataFrame ,
451452 group ,
452- test_size = 0.25 ,
453- train_size = None ,
454- stratify = None ,
455- force = False ,
456- random_state = None ,
457- ):
453+ test_size : Optional [float ] = 0.25 ,
454+ train_size : Optional [float ] = None ,
455+ stratify : Optional [str ] = None ,
456+ force : bool = False ,
457+ random_state : Optional [int ] = None ,
458+ sorted_indices : bool = False ,
459+ ) -> Tuple ["StreamingDataFrame" , "StreamingDataFrame" ]: # noqa: F821
458460 """
459461 This split is for a specific case where data is linked
460462 in one way. Let's assume we have two ids as we have
@@ -472,6 +474,8 @@ def train_test_apart_stratify(
472474 :param force: if True, tries to get at least one example on the test side
473475 for each value of the column *stratify*
474476 :param random_state: seed for random generators
477+ :param sorted_indices: sort index first,
478+ see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`
475479 :return: Two see :class:`StreamingDataFrame
476480 <pandas_streaming.df.dataframe.StreamingDataFrame>`, one
477481 for train, one for test.
@@ -538,10 +542,15 @@ def train_test_apart_stratify(
538542
539543 split = {}
540544 for _ , k in sorted_hist :
541- not_assigned = [c for c in ids [k ] if c not in split ]
545+ indices = sorted (ids [k ]) if sorted_indices else ids [k ]
546+ not_assigned , assigned = [], []
547+ for c in indices :
548+ if c in split :
549+ assigned .append (c )
550+ else :
551+ not_assigned .append (c )
542552 if len (not_assigned ) == 0 :
543553 continue
544- assigned = [c for c in ids [k ] if c in split ]
545554 nb_test = sum (split [c ] for c in assigned )
546555 expected = min (len (ids [k ]), int (test_size * len (ids [k ]) + 0.5 )) - nb_test
547556 if force and expected == 0 and nb_test == 0 :
0 commit comments