SNOW-1051741: df.apply(axis=1) should preserve the original index#3955
SNOW-1051741: df.apply(axis=1) should preserve the original index#3955sfc-gh-jkew wants to merge 21 commits intomainfrom
Conversation
sfc-gh-helmeleegy
left a comment
There was a problem hiding this comment.
LGTM, just had one question.
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
…b/snowpark-python into jkew/apply.axis.1.row.index.0
| if num_index_columns > 0: | ||
| # Columns after row position are index columns, then data columns | ||
| index_cols = df.iloc[:, 1 : 1 + num_index_columns] | ||
| data_cols = df.iloc[:, 1 + num_index_columns :] | ||
|
|
||
| # Set the index using the index columns | ||
| if num_index_columns == 1: | ||
| index = index_cols.iloc[:, 0] | ||
| if index_column_pandas_labels: | ||
| index.name = index_column_pandas_labels[0] | ||
| else: | ||
| # Multi-index case | ||
| index = native_pd.MultiIndex.from_arrays( | ||
| [index_cols.iloc[:, i] for i in range(num_index_columns)], | ||
| names=index_column_pandas_labels | ||
| if index_column_pandas_labels | ||
| else None, | ||
| ) | ||
| data_cols.index = index | ||
| df = data_cols | ||
| else: |
There was a problem hiding this comment.
can't we use set_index() in both cases?
There was a problem hiding this comment.
I meant that you can replace most of the code here with set_index(). See #3979.
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Outdated
Show resolved
Hide resolved
…b/snowpark-python into jkew/apply.axis.1.row.index.0
…b/snowpark-python into jkew/apply.axis.1.row.index.0
| input_types: Snowpark column types of the input data columns (including index columns). | ||
| index_column_pandas_labels: The pandas labels for the index columns, if any. |
There was a problem hiding this comment.
| input_types: Snowpark column types of the input data columns (including index columns). | |
| index_column_pandas_labels: The pandas labels for the index columns, if any. | |
| input_types: Snowpark column types of the input data columns (including index columns). |
|
|
||
|
|
||
| @sql_count_checker(query_count=5, join_count=2, udtf_count=1) | ||
| def test_apply_axis_1_multiindex_preservation(): |
There was a problem hiding this comment.
Could we also test
funcwith return type annotations. We'll use vectorized UDFs instead of UDTFs.funcreturning a series- apply() on series (with func typed, untyped, or returning a series)
| if num_index_columns > 0: | ||
| # Columns after row position are index columns, then data columns | ||
| index_cols = df.iloc[:, 1 : 1 + num_index_columns] | ||
| data_cols = df.iloc[:, 1 + num_index_columns :] | ||
|
|
||
| # Set the index using the index columns | ||
| if num_index_columns == 1: | ||
| index = index_cols.iloc[:, 0] | ||
| if index_column_pandas_labels: | ||
| index.name = index_column_pandas_labels[0] | ||
| else: | ||
| # Multi-index case | ||
| index = native_pd.MultiIndex.from_arrays( | ||
| [index_cols.iloc[:, i] for i in range(num_index_columns)], | ||
| names=index_column_pandas_labels | ||
| if index_column_pandas_labels | ||
| else None, | ||
| ) | ||
| data_cols.index = index | ||
| df = data_cols | ||
| else: |
There was a problem hiding this comment.
I meant that you can replace most of the code here with set_index(). See #3979.
| # Determine if we should pass index columns to the UDTF | ||
| # We pass index columns when the index is not the row position itself |
There was a problem hiding this comment.
We always pass the index column names here. We can keep doing that, but we should update the comment and make the parameter required, since there don't seem to be any other invocations of that function.
| column_index: native_pd.Index, | ||
| input_types: list[DataType], | ||
| session: Session, | ||
| index_column_labels: list[Hashable] | None = None, |
There was a problem hiding this comment.
It turns out that just passing the number of index columns is enough:
df.apply(axis=1)should preserve the original index. Previously we would return a RangeIndex regardless of the original index. This approach passes the index data into the underlying UDTF.Mostly AI written approach, but with original tests for verification.
Fixes SNOW-1051741
Fill out the following pre-review checklist: