diff --git a/tests/test_delete.py b/tests/test_delete.py index f767e4d..780d593 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -33,8 +33,7 @@ async def test_delete_model_by_id_not_found(db: AsyncSession, crud_ins: CRUDPlus async def test_delete_model_with_flush(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): item = sample_ins[1] - async with db.begin(): - count = await crud_ins.delete_model(db, item.id, flush=True) + count = await crud_ins.delete_model(db, item.id, flush=True) assert count == 1 @@ -52,10 +51,15 @@ async def test_delete_model_with_commit(db: AsyncSession, sample_ins: list[Ins], async def test_delete_model_by_column_basic(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): item = sample_ins[3] - async with db.begin(): - count = await crud_ins.delete_model_by_column(db, allow_multiple=True, name=item.name) + before_exists = await crud_ins.exists(db, name=item.name) + assert before_exists is True - assert count >= 0 + count = await crud_ins.delete_model_by_column(db, allow_multiple=True, commit=True, name=item.name) + + assert count > 0 + + after_exists = await crud_ins.exists(db, name=item.name) + assert after_exists is False @pytest.mark.asyncio @@ -68,29 +72,45 @@ async def test_delete_model_by_column_not_found(db: AsyncSession, crud_ins: CRUD @pytest.mark.asyncio async def test_delete_model_by_column_allow_multiple(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): - async with db.begin(): - count = await crud_ins.delete_model_by_column(db, allow_multiple=True, is_deleted=False) + before_count = await crud_ins.count(db, is_deleted=False) + assert before_count > 0 - assert count >= 0 + count = await crud_ins.delete_model_by_column(db, allow_multiple=True, commit=True, is_deleted=False) + + assert count == before_count + + after_count = await crud_ins.count(db, is_deleted=False) + assert after_count == 0 @pytest.mark.asyncio async def test_delete_model_by_column_with_flush(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): item = sample_ins[4] - async with db.begin(): - count = await crud_ins.delete_model_by_column(db, allow_multiple=True, flush=True, name=item.name) + before_exists = await crud_ins.exists(db, name=item.name) + assert before_exists is True - assert count >= 0 + count = await crud_ins.delete_model_by_column(db, allow_multiple=True, commit=True, name=item.name) + + assert count > 0 + + after_exists = await crud_ins.exists(db, name=item.name) + assert after_exists is False @pytest.mark.asyncio async def test_delete_model_by_column_with_commit(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): item = sample_ins[5] + before_exists = await crud_ins.exists(db, name=item.name) + assert before_exists is True + count = await crud_ins.delete_model_by_column(db, allow_multiple=True, commit=True, name=item.name) - assert count >= 0 + assert count > 0 + + after_exists = await crud_ins.exists(db, name=item.name) + assert after_exists is False @pytest.mark.asyncio @@ -104,7 +124,9 @@ async def test_delete_model_by_column_no_filters_error(db: AsyncSession, crud_in async def test_delete_model_by_column_multiple_results_error( db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins] ): - with pytest.raises(Exception): + from sqlalchemy_crud_plus.errors import MultipleResultsError + + with pytest.raises(MultipleResultsError): async with db.begin(): await crud_ins.delete_model_by_column(db, is_deleted=False) @@ -130,20 +152,22 @@ async def test_logical_delete_single_record(db: AsyncSession, sample_ins: list[I @pytest.mark.asyncio async def test_logical_delete_multiple_records(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): - async with db.begin(): - count = await crud_ins.delete_model_by_column( - db, - logical_deletion=True, - deleted_flag_column='is_deleted', - allow_multiple=True, - is_deleted=False, - ) + before_count = await crud_ins.count(db, is_deleted=False) + assert before_count > 0 - assert count >= 0 + count = await crud_ins.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='is_deleted', + allow_multiple=True, + commit=True, + is_deleted=False, + ) - async with db.begin(): - remaining_false = await crud_ins.select_models(db, is_deleted=False) - assert len(remaining_false) >= 0 + assert count == before_count + + remaining_false = await crud_ins.select_models(db, is_deleted=False) + assert len(remaining_false) == 0 @pytest.mark.asyncio @@ -164,37 +188,36 @@ async def test_logical_delete_with_custom_column(db: AsyncSession, sample_ins: l @pytest.mark.asyncio async def test_logical_delete_with_filters(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): - async with db.begin(): - count = await crud_ins.delete_model_by_column( - db, - logical_deletion=True, - deleted_flag_column='is_deleted', - allow_multiple=True, - name__like='item_%', - id__gt=3, - ) + await crud_ins.count(db, name__like='item_%', id__gt=3, is_deleted=False) - assert count >= 0 + count = await crud_ins.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='is_deleted', + allow_multiple=True, + commit=True, + name__like='item_%', + id__gt=3, + ) - async with db.begin(): - deleted_items = await crud_ins.select_models(db, name__like='item_%', id__gt=3, is_deleted=True) + assert count >= 0 - assert len(deleted_items) >= 0 + deleted_items = await crud_ins.select_models(db, name__like='item_%', id__gt=3, is_deleted=True) + assert len(deleted_items) >= count @pytest.mark.asyncio async def test_logical_delete_with_flush(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): item = sample_ins[2] - async with db.begin(): - count = await crud_ins.delete_model_by_column( - db, - logical_deletion=True, - deleted_flag_column='is_deleted', - allow_multiple=False, - flush=True, - id=item.id, - ) + count = await crud_ins.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='is_deleted', + allow_multiple=False, + flush=True, + id=item.id, + ) assert count == 1 @@ -234,20 +257,23 @@ async def test_logical_delete_no_matching_records(db: AsyncSession, crud_ins: CR @pytest.mark.asyncio async def test_logical_delete_already_deleted_records(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): - async with db.begin(): - await crud_ins.delete_model_by_column( - db, logical_deletion=True, deleted_flag_column='is_deleted', allow_multiple=True, id__le=3 - ) + first_count = await crud_ins.delete_model_by_column( + db, logical_deletion=True, deleted_flag_column='is_deleted', allow_multiple=True, commit=True, id__le=3 + ) + assert first_count >= 0 - async with db.begin(): - count = await crud_ins.delete_model_by_column( - db, - logical_deletion=True, - deleted_flag_column='is_deleted', - allow_multiple=True, - id__le=3, - is_deleted=True, - ) + already_deleted_count = await crud_ins.count(db, id__le=3, is_deleted=True) + assert already_deleted_count >= first_count + + count = await crud_ins.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='is_deleted', + allow_multiple=True, + commit=True, + id__le=3, + is_deleted=True, + ) assert count >= 0 @@ -295,6 +321,8 @@ async def test_logical_delete_single_but_multiple_found( async def test_logical_delete_affects_count(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): initial_count = await crud_ins.count(db, is_deleted=False) + to_delete_count = await crud_ins.count(db, id__le=2, is_deleted=False) + deleted_count = await crud_ins.delete_model_by_column( db, logical_deletion=True, @@ -302,16 +330,13 @@ async def test_logical_delete_affects_count(db: AsyncSession, sample_ins: list[I allow_multiple=True, commit=True, id__le=2, + is_deleted=False, ) final_count = await crud_ins.count(db, is_deleted=False) - # 检查是否至少删除了一条记录 - assert deleted_count >= 0 - - # 如果删除了记录,则最终计数应该小于或等于初始计数 - if deleted_count > 0: - assert final_count <= initial_count + assert deleted_count == to_delete_count + assert final_count == initial_count - deleted_count @pytest.mark.asyncio @@ -359,7 +384,6 @@ async def test_logical_delete_with_select_models(db: AsyncSession, sample_ins: l @pytest.mark.asyncio async def test_delete_model_by_column_with_deleted_at(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): - # 使用数据库中存在的项目 item = sample_ins[6] async with db.begin(): @@ -378,7 +402,6 @@ async def test_delete_model_by_column_with_deleted_at(db: AsyncSession, sample_i async def test_delete_model_by_column_without_deleted_at_column( db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins] ): - # 使用数据库中存在的项目 item = sample_ins[7] async with db.begin(): @@ -390,3 +413,30 @@ async def test_delete_model_by_column_without_deleted_at_column( updated_item = await crud_ins.select_model(db, item.id) assert updated_item is not None assert updated_item.is_deleted is True + + +@pytest.mark.asyncio +async def test_delete_model_by_column_with_custom_deleted_at_column( + db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins] +): + from datetime import datetime, timezone + + item = sample_ins[8] + + async with db.begin(): + count = await crud_ins.delete_model_by_column( + db, + logical_deletion=True, + deleted_flag_column='is_deleted', + deleted_at_column='updated_time', + deleted_at_factory=datetime.now(timezone.utc), + id=item.id, + ) + + assert count == 1 + + async with db.begin(): + updated_item = await crud_ins.select_model(db, item.id) + assert updated_item is not None + assert updated_item.is_deleted is True + assert updated_item.updated_time is not None diff --git a/tests/test_filters.py b/tests/test_filters.py index e89ab48..cb1e809 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -113,21 +113,23 @@ async def test_filter_like(db: AsyncSession, sample_ins: list[Ins], crud_ins: CR async def test_filter_not_like(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, name__not_like='nonexistent_%') - assert len(results) >= 0 + assert len(results) > 0 + assert all('nonexistent_' not in r.name for r in results) @pytest.mark.asyncio async def test_filter_ilike(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, name__ilike='ITEM_%') - assert len(results) >= 0 + assert len(results) > 0 + assert all(r.name.lower().startswith('item_') for r in results) @pytest.mark.asyncio async def test_filter_not_ilike(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, name__not_ilike='ITEM_%') - assert len(results) >= 0 + assert all(not r.name.lower().startswith('item_') for r in results) @pytest.mark.asyncio @@ -156,22 +158,22 @@ async def test_filter_match(db: AsyncSession, sample_ins: list[Ins], crud_ins: C try: results = await crud_ins.select_models(db, name__match='item') assert len(results) >= 0 - except Exception: - assert True + except Exception as e: + assert 'match' in str(e).lower() or 'not supported' in str(e).lower() @pytest.mark.asyncio async def test_filter_concat(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, name__concat='_test') - assert len(results) >= 0 + assert isinstance(results, list) @pytest.mark.asyncio async def test_filter_add(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, id__add=1) - assert len(results) >= 0 + assert isinstance(results, list) @pytest.mark.asyncio @@ -202,6 +204,13 @@ async def test_filter_mul(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRU assert len(results) >= 0 +@pytest.mark.asyncio +async def test_filter_mul_with_condition(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): + results = await crud_ins.select_models(db, id__mul={'value': 2, 'condition': {'gt': 0}}) + + assert len(results) >= 0 + + @pytest.mark.asyncio async def test_filter_rmul(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, id__rmul=3) @@ -255,7 +264,9 @@ async def test_filter_rmod(db: AsyncSession, sample_ins: list[Ins], crud_ins: CR async def test_filter_or_same_field_list_values(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, __or__={'is_deleted': [True, False]}) - assert len(results) >= 0 + total_count = await crud_ins.count(db) + assert len(results) == total_count + assert all(r.is_deleted in [True, False] for r in results) @pytest.mark.asyncio @@ -264,14 +275,16 @@ async def test_filter_or_different_fields_single_values( ): results = await crud_ins.select_models(db, __or__={'is_deleted': True, 'id__gt': 5}) - assert len(results) >= 0 + assert len(results) > 0 + assert all(r.is_deleted is True or r.id > 5 for r in results) @pytest.mark.asyncio async def test_filter_or_with_operators(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): results = await crud_ins.select_models(db, __or__={'name__like': 'item_%', 'id__lt': 3}) - assert len(results) >= 0 + assert len(results) > 0 + assert all('item_' in r.name or r.id < 3 for r in results) @pytest.mark.asyncio diff --git a/tests/test_no_relationship.py b/tests/test_no_relationship.py index c8c2ebe..47dc4a6 100644 --- a/tests/test_no_relationship.py +++ b/tests/test_no_relationship.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import pytest -from sqlalchemy import select from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncSession @@ -188,29 +187,6 @@ async def test_exists_with_join_condition( assert isinstance(exists, bool) -@pytest.mark.asyncio -async def test_join_get_user_posts_grouped(db: AsyncSession, no_rel_sample_data: dict): - stmt = select(NoRelUser, NoRelPost).join( - NoRelPost, - NoRelUser.id == NoRelPost.author_id, - isouter=True, - ) - result = await db.execute(stmt) - rows = result.all() - - user_posts = {} - for user, post in rows: - if user.id not in user_posts: - user_posts[user.id] = {'user': user, 'posts': []} - if post: - user_posts[user.id]['posts'].append(post) - - assert len(user_posts) >= 1 - for user_id, data in user_posts.items(): - assert isinstance(data['user'], NoRelUser) - assert isinstance(data['posts'], list) - - @pytest.mark.asyncio async def test_combined_join_filter_order( db: AsyncSession, no_rel_sample_data: dict, no_rel_crud_post: CRUDPlus[NoRelPost] @@ -335,103 +311,6 @@ async def test_join_filter_with_join_conditions( assert len(users) >= 1 assert all(isinstance(user, NoRelUser) for user in users) - stmt = select(NoRelUser, NoRelPost).join( - NoRelPost, - NoRelUser.id == NoRelPost.author_id, - ) - result = await db.execute(stmt) - user_post_pairs = result.all() - - assert len(user_post_pairs) >= 1 - for user, post in user_post_pairs: - assert isinstance(user, NoRelUser) - assert isinstance(post, NoRelPost) - assert user.id == post.author_id - - -@pytest.mark.asyncio -async def test_join_get_multiple_table_data(db: AsyncSession, no_rel_sample_data: dict): - stmt = select(NoRelUser, NoRelProfile).join( - NoRelProfile, - NoRelUser.id == NoRelProfile.user_id, - isouter=True, - ) - result = await db.execute(stmt) - rows = result.all() - - assert len(rows) >= 2 - for user, profile in rows: - assert isinstance(user, NoRelUser) - assert user.name is not None - if profile: - assert isinstance(profile, NoRelProfile) - assert profile.user_id == user.id - - -@pytest.mark.asyncio -async def test_join_get_post_with_author_data(db: AsyncSession, no_rel_sample_data: dict): - stmt = select(NoRelPost, NoRelUser).join( - NoRelUser, - NoRelPost.author_id == NoRelUser.id, - ) - result = await db.execute(stmt) - rows = result.all() - - assert len(rows) >= 1 - for post, user in rows: - assert isinstance(post, NoRelPost) - assert isinstance(user, NoRelUser) - assert post.author_id == user.id - assert post.title is not None - assert user.name is not None - - -@pytest.mark.asyncio -async def test_join_get_three_table_data(db: AsyncSession, no_rel_sample_data: dict): - stmt = ( - select(NoRelPost, NoRelUser, NoRelCategory) - .join(NoRelUser, NoRelPost.author_id == NoRelUser.id) - .join(NoRelCategory, NoRelPost.category_id == NoRelCategory.id, isouter=True) - ) - result = await db.execute(stmt) - rows = result.all() - - assert len(rows) >= 1 - for post, user, category in rows: - assert isinstance(post, NoRelPost) - assert isinstance(user, NoRelUser) - assert post.author_id == user.id - if category: - assert isinstance(category, NoRelCategory) - assert post.category_id == category.id - - -@pytest.mark.asyncio -async def test_join_build_dict_result(db: AsyncSession, no_rel_sample_data: dict): - stmt = select(NoRelUser, NoRelProfile).join( - NoRelProfile, - NoRelUser.id == NoRelProfile.user_id, - ) - result = await db.execute(stmt) - rows = result.all() - - data = [ - { - 'user_id': user.id, - 'user_name': user.name, - 'profile_bio': profile.bio, - } - for user, profile in rows - ] - - assert len(data) >= 1 - for item in data: - assert 'user_id' in item - assert 'user_name' in item - assert 'profile_bio' in item - assert isinstance(item['user_id'], int) - assert isinstance(item['user_name'], str) - @pytest.mark.asyncio async def test_join_with_fill_result_true( @@ -516,9 +395,10 @@ async def test_join_multiple_with_fill_result( async def test_join_fill_result_single_model( db: AsyncSession, no_rel_sample_data: dict, no_rel_crud_user: CRUDPlus[NoRelUser] ): + user_id = no_rel_sample_data['users'][0].id result = await no_rel_crud_user.select_model( db, - no_rel_sample_data['users'][0].id, + user_id, join_conditions=[ JoinConfig( model=NoRelProfile, diff --git a/tests/test_select.py b/tests/test_select.py index 5bff019..767a42a 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -2,10 +2,13 @@ # -*- coding: utf-8 -*- import pytest +from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus import CRUDPlus +from sqlalchemy_crud_plus.types import JoinConfig from tests.models.basic import Ins +from tests.models.no_relationship import NoRelProfile, NoRelUser @pytest.mark.asyncio @@ -255,3 +258,81 @@ async def test_exists_with_kwargs(db: AsyncSession, sample_ins: list[Ins], crud_ exists = await crud_ins.exists(db, is_deleted=False) assert isinstance(exists, bool) + + +@pytest.mark.asyncio +async def test_select_model_with_fill_result(db: AsyncSession, no_rel_sample_data: dict): + crud_user = CRUDPlus(NoRelUser) + user = no_rel_sample_data['users'][0] + + result = await crud_user.select_model( + db, + user.id, + join_conditions=[ + JoinConfig( + model=NoRelProfile, + join_on=NoRelUser.id == NoRelProfile.user_id, + join_type='left', + fill_result=True, + ) + ], + ) + + assert result is not None + assert isinstance(result, (tuple, Row)) + assert len(result) == 2 + assert isinstance(result[0], NoRelUser) + if result[1]: + assert isinstance(result[1], NoRelProfile) + + +@pytest.mark.asyncio +async def test_select_model_by_column_with_fill_result(db: AsyncSession, no_rel_sample_data: dict): + crud_user = CRUDPlus(NoRelUser) + user = no_rel_sample_data['users'][0] + + result = await crud_user.select_model_by_column( + db, + name=user.name, + join_conditions=[ + JoinConfig( + model=NoRelProfile, + join_on=NoRelUser.id == NoRelProfile.user_id, + join_type='left', + fill_result=True, + ) + ], + ) + + assert result is not None + assert isinstance(result, (tuple, Row)) + assert len(result) == 2 + assert isinstance(result[0], NoRelUser) + if result[1]: + assert isinstance(result[1], NoRelProfile) + + +@pytest.mark.asyncio +async def test_select_models_order_with_fill_result(db: AsyncSession, no_rel_sample_data: dict): + crud_user = CRUDPlus(NoRelUser) + + results = await crud_user.select_models_order( + db, + 'name', + join_conditions=[ + JoinConfig( + model=NoRelProfile, + join_on=NoRelUser.id == NoRelProfile.user_id, + join_type='left', + fill_result=True, + ) + ], + ) + + assert len(results) >= 1 + for result in results: + assert isinstance(result, (tuple, Row)) + assert len(result) == 2 + assert isinstance(result[0], NoRelUser) + if result[1]: + assert isinstance(result[1], NoRelProfile) diff --git a/tests/test_update.py b/tests/test_update.py index 452c760..1417c6d 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -35,8 +35,7 @@ async def test_update_model_with_flush(db: AsyncSession, sample_ins: list[Ins], item = sample_ins[0] update_data = UpdateIns(name='flush_update') - async with db.begin(): - result = await crud_ins.update_model(db, item.id, update_data, flush=True) + result = await crud_ins.update_model(db, item.id, update_data, flush=True) assert result == 1 @@ -98,10 +97,15 @@ async def test_update_model_by_column_not_found(db: AsyncSession, crud_ins: CRUD async def test_update_model_by_column_allow_multiple(db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins]): update_data = UpdateIns(name='multiple_update') - async with db.begin(): - result = await crud_ins.update_model_by_column(db, update_data, allow_multiple=True, is_deleted=False) + before_count = await crud_ins.count(db, is_deleted=False) + assert before_count > 0 + + result = await crud_ins.update_model_by_column(db, update_data, allow_multiple=True, commit=True, is_deleted=False) - assert result >= 0 + assert result == before_count + + updated_records = await crud_ins.select_models(db, name='multiple_update') + assert len(updated_records) == before_count @pytest.mark.asyncio @@ -109,8 +113,7 @@ async def test_update_model_by_column_with_flush(db: AsyncSession, sample_ins: l item = sample_ins[0] update_data = UpdateIns(name='flush_column_update') - async with db.begin(): - result = await crud_ins.update_model_by_column(db, update_data, flush=True, id=item.id) + result = await crud_ins.update_model_by_column(db, update_data, flush=True, id=item.id) assert result == 1 @@ -149,9 +152,11 @@ async def test_update_model_by_column_no_filters_error(db: AsyncSession, crud_in async def test_update_model_by_column_multiple_results_error( db: AsyncSession, sample_ins: list[Ins], crud_ins: CRUDPlus[Ins] ): + from sqlalchemy_crud_plus.errors import MultipleResultsError + update_data = UpdateIns(name='multiple_error') - with pytest.raises(Exception): + with pytest.raises(MultipleResultsError): async with db.begin(): await crud_ins.update_model_by_column(db, update_data, is_deleted=False) @@ -229,7 +234,6 @@ async def test_bulk_update_models_with_pydantic_schema(db: AsyncSession, crud_in @pytest.mark.asyncio async def test_bulk_update_models_pk_mode_false_no_filters_error(db: AsyncSession, crud_ins: CRUDPlus[Ins]): - """测试 bulk_update_models pk_mode=False 时没有过滤条件的错误""" update_data = [{'name': 'no_filters'}] with pytest.raises(ValueError, match='At least one filter condition must be provided'): @@ -277,10 +281,9 @@ async def test_bulk_update_models_pk_mode_false_with_flush(db: AsyncSession, cru update_data = [{'name': 'updated_flush_1'}, {'name': 'updated_flush_2'}] - async with db.begin(): - result = await crud_ins.bulk_update_models( - db, update_data, pk_mode=False, flush=True, name__like='bulk_update_flush_%' - ) + result = await crud_ins.bulk_update_models( + db, update_data, pk_mode=False, flush=True, name__like='bulk_update_flush_%' + ) assert result == 2 diff --git a/tests/test_utils.py b/tests/test_utils.py index 4c0da87..8502681 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -102,14 +102,38 @@ def test_valid_columns_name(self): assert column.name == 'name' def test_invalid_column(self): - with pytest.raises(ModelColumnError): + with pytest.raises(ModelColumnError) as exc_info: get_column(Ins, 'nonexistent_column') + assert str(exc_info.value) def test_aliased_model(self): aliased_ins = aliased(Ins) column = get_column(aliased_ins, 'name') assert column is not None + def test_invalid_column_property(self): + from sqlalchemy import ForeignKey + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + class Base(DeclarativeBase): + pass + + class TestModel(Base): + __tablename__ = 'test_model' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + class RelatedModel(Base): + __tablename__ = 'related_model' + id: Mapped[int] = mapped_column(primary_key=True) + test_id: Mapped[int] = mapped_column(ForeignKey('test_model.id')) + + TestModel.related = relationship(RelatedModel) + + with pytest.raises(ModelColumnError) as exc_info: + get_column(TestModel, 'related') + assert str(exc_info.value) + class TestParseFilters: def test_basic_filters(self):