diff --git a/src/openedx_content/applets/publishing/api.py b/src/openedx_content/applets/publishing/api.py index 14a71e418..5e7e7f867 100644 --- a/src/openedx_content/applets/publishing/api.py +++ b/src/openedx_content/applets/publishing/api.py @@ -10,7 +10,7 @@ from contextlib import nullcontext from datetime import datetime, timezone from functools import partial -from typing import ContextManager, Optional, cast +from typing import ContextManager, Optional, Sequence, cast from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist @@ -76,10 +76,12 @@ ] -def get_learning_package(learning_package_id: LearningPackage.ID, /) -> LearningPackage: +def get_learning_package(learning_package_id: LearningPackage | LearningPackage.ID, /) -> LearningPackage: """ Get LearningPackage by ID. """ + if isinstance(learning_package_id, LearningPackage): + return learning_package_id return LearningPackage.objects.get(id=learning_package_id) @@ -129,7 +131,7 @@ def send_event(): def update_learning_package( - learning_package_id: LearningPackage.ID, + learning_package: LearningPackage | LearningPackage.ID, /, package_ref: str | None = None, title: str | None = None, @@ -141,34 +143,34 @@ def update_learning_package( Note that LearningPackage itself is not versioned (only stuff inside it is). """ - lp = LearningPackage.objects.get(id=learning_package_id) + learning_package = get_learning_package(learning_package) # If no changes were requested, there's nothing to update, so just return # the LearningPackage as-is. if all(field is None for field in [package_ref, title, description, updated]): - return lp + return learning_package if package_ref is not None: - lp.package_ref = package_ref + learning_package.package_ref = package_ref if title is not None: - lp.title = title + learning_package.title = title if description is not None: - lp.description = description + learning_package.description = description # updated is a bit different–we auto-generate it if it's not explicitly # passed in. if updated is None: updated = datetime.now(tz=timezone.utc) - lp.updated = updated + learning_package.updated = updated - lp.save() + learning_package.save() # Emit LEARNING_PACKAGE_UPDATED once the transaction commits. Note: we only # reach this point if at least one of key/title/description/updated was # passed in (the early-return above handles the no-op case), so the update # really did touch the row. - lp_id = lp.id - lp_title = lp.title + lp_id = learning_package.id + lp_title = learning_package.title def send_event(): signals.LEARNING_PACKAGE_UPDATED.send_event( @@ -176,8 +178,7 @@ def send_event(): ) on_commit(send_event) - - return lp + return learning_package def learning_package_exists(package_ref: str) -> bool: @@ -188,7 +189,7 @@ def learning_package_exists(package_ref: str) -> bool: def create_publishable_entity( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, entity_ref: str, created: datetime, @@ -203,6 +204,10 @@ def create_publishable_entity( You'd typically want to call this right before creating your own content model that points to it. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects.create( learning_package_id=learning_package_id, entity_ref=entity_ref, @@ -213,14 +218,14 @@ def create_publishable_entity( def create_publishable_entity_version( - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, /, version_num: int, title: str, created: datetime, created_by: int | None, *, - dependencies: list[PublishableEntity.ID] | None = None, + dependencies: Sequence[PublishableEntity.ID] | None = None, ) -> PublishableEntityVersion: """ Create a PublishableEntityVersion. @@ -228,6 +233,10 @@ def create_publishable_entity_version( You'd typically want to call this right before creating your own content version model that points to it. """ + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) with atomic(savepoint=False): version = PublishableEntityVersion.objects.create( entity_id=entity_id, @@ -249,9 +258,9 @@ def create_publishable_entity_version( def set_version_dependencies( - version_id: int, # PublishableEntityVersion.id, + version_id: PublishableEntityVersion | int, /, - dependencies: list[PublishableEntity.ID], + dependencies: Sequence[PublishableEntity.ID], ) -> None: """ Set the dependencies of a publishable entity version. @@ -301,6 +310,10 @@ def set_version_dependencies( will cause recalculation to the higher levels that depend on it. 3. Do not create circular dependencies. """ + version_id = ( + version_id.id if isinstance(version_id, PublishableEntityVersion) + else version_id + ) PublishableEntityVersionDependency.objects.bulk_create( [ PublishableEntityVersionDependency( @@ -312,18 +325,40 @@ def set_version_dependencies( ) -def get_publishable_entity(publishable_entity_id: PublishableEntity.ID, /) -> PublishableEntity: - return PublishableEntity.objects.get(pk=publishable_entity_id) +def get_publishable_entity(entity_id: PublishableEntity | PublishableEntity.ID, /) -> PublishableEntity: + """ + Get a learning package by its database primary key. + """ + if isinstance(entity_id, PublishableEntity): + return entity_id + return PublishableEntity.objects.get(id=entity_id) -def get_publishable_entity_by_ref(learning_package_id: LearningPackage.ID, /, entity_ref: str) -> PublishableEntity: +def get_publishable_entity_by_ref( + learning_package_id: LearningPackage | LearningPackage.ID, /, + entity_ref: str, +) -> PublishableEntity: + """ + Given a learning package and an entity_ref, get the matching publishable entity. + """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects.get( learning_package_id=learning_package_id, entity_ref=entity_ref, ) -def get_last_publish(learning_package_id: LearningPackage.ID, /) -> PublishLog | None: +def get_last_publish(learning_package_id: LearningPackage | LearningPackage.ID, /) -> PublishLog | None: + """ + Get the log of the latest publish in the LearningPackage, or None if nothing has been published. + """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishLog.objects \ .filter(learning_package_id=learning_package_id) \ .order_by('-id') \ @@ -337,10 +372,16 @@ def get_all_drafts(learning_package_id: LearningPackage.ID, /) -> QuerySet[Draft ) -def get_publishable_entities(learning_package_id: LearningPackage.ID, /) -> QuerySet[PublishableEntity]: +def get_publishable_entities( + learning_package_id: LearningPackage | LearningPackage.ID, / +) -> QuerySet[PublishableEntity]: """ Get all entities in a learning package. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -352,7 +393,7 @@ def get_publishable_entities(learning_package_id: LearningPackage.ID, /) -> Quer def get_entities_with_unpublished_changes( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, include_deleted_drafts: bool = False ) -> QuerySet[PublishableEntity]: @@ -362,6 +403,10 @@ def get_entities_with_unpublished_changes( By default, this excludes soft-deleted drafts but can be included using include_deleted_drafts option. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) entities_qs = ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -384,12 +429,18 @@ def get_entities_with_unpublished_changes( return entities_qs.exclude(draft__version__isnull=True) -def get_entities_with_unpublished_deletes(learning_package_id: LearningPackage.ID, /) -> QuerySet[PublishableEntity]: +def get_entities_with_unpublished_deletes( + learning_package_id: LearningPackage | LearningPackage.ID, / +) -> QuerySet[PublishableEntity]: """ Something will become "deleted" if it has a null Draft version but a not-null Published version. (If both are null, it means it's already been deleted in a previous publish, or it was never published.) """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects \ .filter( learning_package_id=learning_package_id, @@ -398,7 +449,7 @@ def get_entities_with_unpublished_deletes(learning_package_id: LearningPackage.I def publish_all_drafts( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, message="", published_at: datetime | None = None, @@ -407,6 +458,10 @@ def publish_all_drafts( """ Publish everything that is a Draft and is not already published. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) draft_qset = ( Draft.objects .filter(entity__learning_package_id=learning_package_id) @@ -451,7 +506,7 @@ def _get_dependencies_with_unpublished_changes( def publish_from_drafts( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, draft_qset: QuerySet[Draft], message: str = "", @@ -466,6 +521,10 @@ def publish_from_drafts( By default, this will also publish all dependencies (e.g. unpinned children) of the Drafts that are passed in. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if published_at is None: published_at = datetime.now(tz=timezone.utc) @@ -598,7 +657,7 @@ def get_published_version( def get_entity_draft_history( - publishable_entity_or_id: PublishableEntity | int, / + entity_id: PublishableEntity | PublishableEntity.ID, / ) -> QuerySet[DraftChangeLogRecord]: """ [ 🛑 UNSTABLE ] @@ -617,11 +676,10 @@ def get_entity_draft_history( soft-delete DraftChangeLogRecord (new_version=None) is included because it was made after the last real publish. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id - + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) qs = ( DraftChangeLogRecord.objects .filter(entity_id=entity_id) @@ -666,7 +724,7 @@ def get_entity_draft_history( def get_entity_publish_history( - publishable_entity_or_id: PublishableEntity | int, / + entity_id: PublishableEntity | PublishableEntity.ID, / ) -> QuerySet[PublishLogRecord]: """ [ 🛑 UNSTABLE ] @@ -681,11 +739,10 @@ def get_entity_publish_history( PublishLogRecord captures only the version that was actually published, not the intermediate draft versions. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id - + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) return ( PublishLogRecord.objects .filter(entity_id=entity_id) @@ -699,7 +756,7 @@ def get_entity_publish_history( def get_entity_publish_history_entries( - publishable_entity_or_id: PublishableEntity | int, + entity_id: PublishableEntity | PublishableEntity.ID, /, publish_log_uuid: str, ) -> QuerySet[DraftChangeLogRecord]: @@ -728,10 +785,10 @@ def get_entity_publish_history_entries( Raises PublishLogRecord.DoesNotExist if publish_log_uuid is not found for this entity. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) # Fetch the PublishLogRecord for the requested PublishLog pub_record = ( @@ -796,7 +853,7 @@ def get_entity_publish_history_entries( def get_entity_version_contributors( - publishable_entity_or_id: PublishableEntity | int, + entity_id: PublishableEntity | PublishableEntity.ID, /, old_version_num: int, new_version_num: int | None, @@ -821,10 +878,10 @@ def get_entity_version_contributors( - A user who contributed multiple versions in the range appears only once (results are deduplicated with DISTINCT). """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) if new_version_num is not None: version_filter = Q( @@ -866,8 +923,8 @@ def get_entity_version_contributors( def set_draft_version( - draft_or_id: Draft | PublishableEntity.ID, - publishable_entity_version_pk: int | None, + draft: Draft | PublishableEntity.ID, + version_id: PublishableEntityVersion | int | None, /, set_at: datetime | None = None, set_by: int | None = None, # User.id @@ -897,27 +954,29 @@ def set_draft_version( """ if set_at is None: set_at = datetime.now(tz=timezone.utc) + version_id = ( + version_id.id if isinstance(version_id, PublishableEntityVersion) + else version_id + ) with atomic(savepoint=False): - if isinstance(draft_or_id, Draft): - draft = draft_or_id - elif isinstance(draft_or_id, int): + if isinstance(draft, int): draft, _created = Draft.objects.select_related("entity") \ - .get_or_create(entity_id=draft_or_id) - else: - class_name = draft_or_id.__class__.__name__ + .get_or_create(entity_id=draft) + elif not isinstance(draft, Draft): + class_name = draft.__class__.__name__ raise TypeError( f"draft_or_id must be a Draft or int, not ({class_name})" ) # If the Draft is already pointing at this version, there's nothing to do. old_version_id = draft.version_id - if old_version_id == publishable_entity_version_pk: + if old_version_id == version_id: return # The actual update of the Draft model is here. Everything after this # block is bookkeeping in our DraftChangeLog. - draft.version_id = publishable_entity_version_pk + draft.version_id = version_id # Check to see if we're inside a context manager for an active # DraftChangeLog (i.e. what happens if the caller is using the public @@ -931,7 +990,7 @@ def set_draft_version( active_change_log, draft.entity_id, old_version_id=old_version_id, - new_version_id=publishable_entity_version_pk, + new_version_id=version_id, ) if draft_log_record: # Normal case: a DraftChangeLogRecord was created or updated. @@ -984,7 +1043,7 @@ def set_draft_version( draft_change_log=change_log, entity_id=draft.entity_id, old_version_id=old_version_id, - new_version_id=publishable_entity_version_pk, + new_version_id=version_id, ) draft.save() _create_side_effects_for_change_log(change_log) @@ -994,7 +1053,7 @@ def set_draft_version( def _add_to_existing_draft_change_log( active_change_log: DraftChangeLog, - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, old_version_id: int | None, new_version_id: int | None, ) -> DraftChangeLogRecord | None: @@ -1021,6 +1080,10 @@ def _add_to_existing_draft_change_log( log records (the only place where it's normal to have the same old and new versions). """ + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) try: # Check to see if this PublishableEntity has already been changed in # this DraftChangeLog. If so, we update that record instead of creating @@ -1523,7 +1586,9 @@ def hash_for_log_record( return digest -def soft_delete_draft(publishable_entity_id: PublishableEntity.ID, /, deleted_by: int | None = None) -> None: +def soft_delete_draft( + entity_id: PublishableEntity | PublishableEntity.ID, /, deleted_by: int | None = None +) -> None: """ Sets the Draft version to None. @@ -1533,11 +1598,15 @@ def soft_delete_draft(publishable_entity_id: PublishableEntity.ID, /, deleted_by of pointing the Draft back to the most recent ``PublishableEntityVersion`` for a given ``PublishableEntity``. """ - return set_draft_version(publishable_entity_id, None, set_by=deleted_by) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) + return set_draft_version(entity_id, None, set_by=deleted_by) def reset_drafts_to_published( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, reset_at: datetime | None = None, reset_by: int | None = None, # User.id @@ -1566,6 +1635,10 @@ def reset_drafts_to_published( latest version created for a PublishableEntity (its ``latest`` attribute), rather than basing it off of the version that Draft points to. """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if reset_at is None: reset_at = datetime.now(tz=timezone.utc) @@ -1665,7 +1738,7 @@ def filter_publishable_entities( def get_published_version_as_of( - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, publish_log_id: int, ) -> PublishableEntityVersion | None: """ @@ -1675,6 +1748,10 @@ def get_published_version_as_of( This is a semi-private function, only available to other apps in the authoring package. """ + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) record = ( PublishLogRecord.objects.filter( entity_id=entity_id, @@ -1687,7 +1764,7 @@ def get_published_version_as_of( def bulk_draft_changes_for( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, changed_by: int | None = None, changed_at: datetime | None = None ) -> DraftChangeLogContext: @@ -1720,6 +1797,10 @@ def bulk_draft_changes_for( with bulk_draft_changes_for(component.learning_package.id): update_one_component(component.learning_package.id, component) """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if not changed_at: changed_at = datetime.now(tz=timezone.utc) return DraftChangeLogContext( diff --git a/src/openedx_core/__init__.py b/src/openedx_core/__init__.py index 863a0e058..a835233dc 100644 --- a/src/openedx_core/__init__.py +++ b/src/openedx_core/__init__.py @@ -6,4 +6,4 @@ """ # The version for the entire repository -__version__ = "0.46.0" +__version__ = "0.47.0" diff --git a/tests/openedx_django_lib/__init__.py b/tests/openedx_django_lib/__init__.py new file mode 100644 index 000000000..e69de29bb