From 8dc3240fef2798db3d623d74f2c01272b65f3d5f Mon Sep 17 00:00:00 2001 From: Kyle D McCormick Date: Mon, 27 Apr 2026 12:59:02 -0400 Subject: [PATCH 1/3] feat: Allow IDs or Models to be passed to all publishing APIs Part of: https://github.com/openedx/openedx-core/issues/562 --- src/openedx_content/applets/publishing/api.py | 153 ++++++++++-------- src/openedx_core/__init__.py | 2 +- src/openedx_django_lib/typing.py | 25 +++ tests/openedx_django_lib/__init__.py | 0 tests/openedx_django_lib/test_typing.py | 35 ++++ tests/test_django_app/models.py | 32 +++- 6 files changed, 176 insertions(+), 71 deletions(-) create mode 100644 src/openedx_django_lib/typing.py create mode 100644 tests/openedx_django_lib/__init__.py create mode 100644 tests/openedx_django_lib/test_typing.py diff --git a/src/openedx_content/applets/publishing/api.py b/src/openedx_content/applets/publishing/api.py index 14a71e418..f5589c29e 100644 --- a/src/openedx_content/applets/publishing/api.py +++ b/src/openedx_content/applets/publishing/api.py @@ -10,7 +10,7 @@ from contextlib import nullcontext from datetime import datetime, timezone from functools import partial -from typing import ContextManager, Optional, cast +from typing import ContextManager, Optional, Sequence, cast from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist @@ -18,6 +18,7 @@ from django.db.transaction import atomic, on_commit from openedx_django_lib.fields import create_hash_digest +from openedx_django_lib.typing import get_model_id from . import signals from .contextmanagers import DraftChangeLogContext @@ -76,10 +77,12 @@ ] -def get_learning_package(learning_package_id: LearningPackage.ID, /) -> LearningPackage: +def get_learning_package(learning_package_id: LearningPackage | LearningPackage.ID, /) -> LearningPackage: """ Get LearningPackage by ID. """ + if isinstance(learning_package_id, LearningPackage): + return learning_package_id return LearningPackage.objects.get(id=learning_package_id) @@ -129,7 +132,7 @@ def send_event(): def update_learning_package( - learning_package_id: LearningPackage.ID, + learning_package: LearningPackage | LearningPackage.ID, /, package_ref: str | None = None, title: str | None = None, @@ -141,34 +144,34 @@ def update_learning_package( Note that LearningPackage itself is not versioned (only stuff inside it is). """ - lp = LearningPackage.objects.get(id=learning_package_id) + learning_package = get_learning_package(learning_package) # If no changes were requested, there's nothing to update, so just return # the LearningPackage as-is. if all(field is None for field in [package_ref, title, description, updated]): - return lp + return learning_package if package_ref is not None: - lp.package_ref = package_ref + learning_package.package_ref = package_ref if title is not None: - lp.title = title + learning_package.title = title if description is not None: - lp.description = description + learning_package.description = description # updated is a bit different–we auto-generate it if it's not explicitly # passed in. if updated is None: updated = datetime.now(tz=timezone.utc) - lp.updated = updated + learning_package.updated = updated - lp.save() + learning_package.save() # Emit LEARNING_PACKAGE_UPDATED once the transaction commits. Note: we only # reach this point if at least one of key/title/description/updated was # passed in (the early-return above handles the no-op case), so the update # really did touch the row. - lp_id = lp.id - lp_title = lp.title + lp_id = learning_package.id + lp_title = learning_package.title def send_event(): signals.LEARNING_PACKAGE_UPDATED.send_event( @@ -176,8 +179,7 @@ def send_event(): ) on_commit(send_event) - - return lp + return learning_package def learning_package_exists(package_ref: str) -> bool: @@ -188,7 +190,7 @@ def learning_package_exists(package_ref: str) -> bool: def create_publishable_entity( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, entity_ref: str, created: datetime, @@ -203,6 +205,7 @@ def create_publishable_entity( You'd typically want to call this right before creating your own content model that points to it. """ + learning_package_id = get_model_id(learning_package_id) return PublishableEntity.objects.create( learning_package_id=learning_package_id, entity_ref=entity_ref, @@ -213,14 +216,14 @@ def create_publishable_entity( def create_publishable_entity_version( - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, /, version_num: int, title: str, created: datetime, created_by: int | None, *, - dependencies: list[PublishableEntity.ID] | None = None, + dependencies: Sequence[PublishableEntity.ID] | None = None, ) -> PublishableEntityVersion: """ Create a PublishableEntityVersion. @@ -228,6 +231,7 @@ def create_publishable_entity_version( You'd typically want to call this right before creating your own content version model that points to it. """ + entity_id = get_model_id(entity_id) with atomic(savepoint=False): version = PublishableEntityVersion.objects.create( entity_id=entity_id, @@ -249,9 +253,9 @@ def create_publishable_entity_version( def set_version_dependencies( - version_id: int, # PublishableEntityVersion.id, + version_id: PublishableEntityVersion | int, /, - dependencies: list[PublishableEntity.ID], + dependencies: Sequence[PublishableEntity.ID], ) -> None: """ Set the dependencies of a publishable entity version. @@ -301,6 +305,7 @@ def set_version_dependencies( will cause recalculation to the higher levels that depend on it. 3. Do not create circular dependencies. """ + version_id = get_model_id(version_id) PublishableEntityVersionDependency.objects.bulk_create( [ PublishableEntityVersionDependency( @@ -312,18 +317,31 @@ def set_version_dependencies( ) -def get_publishable_entity(publishable_entity_id: PublishableEntity.ID, /) -> PublishableEntity: - return PublishableEntity.objects.get(pk=publishable_entity_id) +def get_publishable_entity(entity_id: PublishableEntity | PublishableEntity.ID, /) -> PublishableEntity: + """ + Get a learning package by its database primary key. + """ + if isinstance(entity_id, PublishableEntity): + return entity_id + return PublishableEntity.objects.get(id=entity_id) -def get_publishable_entity_by_ref(learning_package_id: LearningPackage.ID, /, entity_ref: str) -> PublishableEntity: +def get_publishable_entity_by_ref( + learning_package_id: LearningPackage | LearningPackage.ID, /, + entity_ref: str, +) -> PublishableEntity: + """ + Given a learning package and an entity_ref, get the matching publishable entity. + """ + learning_package_id = get_model_id(learning_package_id) return PublishableEntity.objects.get( learning_package_id=learning_package_id, entity_ref=entity_ref, ) -def get_last_publish(learning_package_id: LearningPackage.ID, /) -> PublishLog | None: +def get_last_publish(learning_package_id: LearningPackage | LearningPackage.ID, /) -> PublishLog | None: + learning_package_id = get_model_id(learning_package_id) return PublishLog.objects \ .filter(learning_package_id=learning_package_id) \ .order_by('-id') \ @@ -337,10 +355,13 @@ def get_all_drafts(learning_package_id: LearningPackage.ID, /) -> QuerySet[Draft ) -def get_publishable_entities(learning_package_id: LearningPackage.ID, /) -> QuerySet[PublishableEntity]: +def get_publishable_entities( + learning_package_id: LearningPackage | LearningPackage.ID, / +) -> QuerySet[PublishableEntity]: """ Get all entities in a learning package. """ + learning_package_id = get_model_id(learning_package_id) return ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -352,7 +373,7 @@ def get_publishable_entities(learning_package_id: LearningPackage.ID, /) -> Quer def get_entities_with_unpublished_changes( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, include_deleted_drafts: bool = False ) -> QuerySet[PublishableEntity]: @@ -362,6 +383,7 @@ def get_entities_with_unpublished_changes( By default, this excludes soft-deleted drafts but can be included using include_deleted_drafts option. """ + learning_package_id = get_model_id(learning_package_id) entities_qs = ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -384,12 +406,15 @@ def get_entities_with_unpublished_changes( return entities_qs.exclude(draft__version__isnull=True) -def get_entities_with_unpublished_deletes(learning_package_id: LearningPackage.ID, /) -> QuerySet[PublishableEntity]: +def get_entities_with_unpublished_deletes( + learning_package_id: LearningPackage | LearningPackage.ID, / +) -> QuerySet[PublishableEntity]: """ Something will become "deleted" if it has a null Draft version but a not-null Published version. (If both are null, it means it's already been deleted in a previous publish, or it was never published.) """ + learning_package_id = get_model_id(learning_package_id) return PublishableEntity.objects \ .filter( learning_package_id=learning_package_id, @@ -398,7 +423,7 @@ def get_entities_with_unpublished_deletes(learning_package_id: LearningPackage.I def publish_all_drafts( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, message="", published_at: datetime | None = None, @@ -407,6 +432,7 @@ def publish_all_drafts( """ Publish everything that is a Draft and is not already published. """ + learning_package_id = get_model_id(learning_package_id) draft_qset = ( Draft.objects .filter(entity__learning_package_id=learning_package_id) @@ -451,7 +477,7 @@ def _get_dependencies_with_unpublished_changes( def publish_from_drafts( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, draft_qset: QuerySet[Draft], message: str = "", @@ -466,6 +492,7 @@ def publish_from_drafts( By default, this will also publish all dependencies (e.g. unpinned children) of the Drafts that are passed in. """ + learning_package_id = get_model_id(learning_package_id) if published_at is None: published_at = datetime.now(tz=timezone.utc) @@ -598,7 +625,7 @@ def get_published_version( def get_entity_draft_history( - publishable_entity_or_id: PublishableEntity | int, / + entity_id: PublishableEntity | PublishableEntity.ID, / ) -> QuerySet[DraftChangeLogRecord]: """ [ 🛑 UNSTABLE ] @@ -617,11 +644,7 @@ def get_entity_draft_history( soft-delete DraftChangeLogRecord (new_version=None) is included because it was made after the last real publish. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id - + entity_id = get_model_id(entity_id) qs = ( DraftChangeLogRecord.objects .filter(entity_id=entity_id) @@ -666,7 +689,7 @@ def get_entity_draft_history( def get_entity_publish_history( - publishable_entity_or_id: PublishableEntity | int, / + entity_id: PublishableEntity | PublishableEntity.ID, / ) -> QuerySet[PublishLogRecord]: """ [ 🛑 UNSTABLE ] @@ -681,11 +704,7 @@ def get_entity_publish_history( PublishLogRecord captures only the version that was actually published, not the intermediate draft versions. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id - + entity_id = get_model_id(entity_id) return ( PublishLogRecord.objects .filter(entity_id=entity_id) @@ -699,7 +718,7 @@ def get_entity_publish_history( def get_entity_publish_history_entries( - publishable_entity_or_id: PublishableEntity | int, + entity_id: PublishableEntity | PublishableEntity.ID, /, publish_log_uuid: str, ) -> QuerySet[DraftChangeLogRecord]: @@ -728,10 +747,7 @@ def get_entity_publish_history_entries( Raises PublishLogRecord.DoesNotExist if publish_log_uuid is not found for this entity. """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id + entity_id = get_model_id(entity_id) # Fetch the PublishLogRecord for the requested PublishLog pub_record = ( @@ -796,7 +812,7 @@ def get_entity_publish_history_entries( def get_entity_version_contributors( - publishable_entity_or_id: PublishableEntity | int, + entity_id: PublishableEntity | PublishableEntity.ID, /, old_version_num: int, new_version_num: int | None, @@ -821,10 +837,7 @@ def get_entity_version_contributors( - A user who contributed multiple versions in the range appears only once (results are deduplicated with DISTINCT). """ - if isinstance(publishable_entity_or_id, int): - entity_id = PublishableEntity.PublishableEntityID(publishable_entity_or_id) - else: - entity_id = publishable_entity_or_id.id + entity_id = get_model_id(entity_id) if new_version_num is not None: version_filter = Q( @@ -866,8 +879,8 @@ def get_entity_version_contributors( def set_draft_version( - draft_or_id: Draft | PublishableEntity.ID, - publishable_entity_version_pk: int | None, + draft: Draft | PublishableEntity.ID, + entity_version_id: PublishableEntityVersion | int | None, /, set_at: datetime | None = None, set_by: int | None = None, # User.id @@ -897,27 +910,26 @@ def set_draft_version( """ if set_at is None: set_at = datetime.now(tz=timezone.utc) + entity_version_id = get_model_id(entity_version_id) if entity_version_id else None with atomic(savepoint=False): - if isinstance(draft_or_id, Draft): - draft = draft_or_id - elif isinstance(draft_or_id, int): + if isinstance(draft, int): draft, _created = Draft.objects.select_related("entity") \ - .get_or_create(entity_id=draft_or_id) - else: - class_name = draft_or_id.__class__.__name__ + .get_or_create(entity_id=draft) + elif not isinstance(draft, Draft): + class_name = draft.__class__.__name__ raise TypeError( f"draft_or_id must be a Draft or int, not ({class_name})" ) # If the Draft is already pointing at this version, there's nothing to do. old_version_id = draft.version_id - if old_version_id == publishable_entity_version_pk: + if old_version_id == entity_version_id: return # The actual update of the Draft model is here. Everything after this # block is bookkeeping in our DraftChangeLog. - draft.version_id = publishable_entity_version_pk + draft.version_id = entity_version_id # Check to see if we're inside a context manager for an active # DraftChangeLog (i.e. what happens if the caller is using the public @@ -931,7 +943,7 @@ def set_draft_version( active_change_log, draft.entity_id, old_version_id=old_version_id, - new_version_id=publishable_entity_version_pk, + new_version_id=entity_version_id, ) if draft_log_record: # Normal case: a DraftChangeLogRecord was created or updated. @@ -984,7 +996,7 @@ def set_draft_version( draft_change_log=change_log, entity_id=draft.entity_id, old_version_id=old_version_id, - new_version_id=publishable_entity_version_pk, + new_version_id=entity_version_id, ) draft.save() _create_side_effects_for_change_log(change_log) @@ -994,7 +1006,7 @@ def set_draft_version( def _add_to_existing_draft_change_log( active_change_log: DraftChangeLog, - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, old_version_id: int | None, new_version_id: int | None, ) -> DraftChangeLogRecord | None: @@ -1021,6 +1033,7 @@ def _add_to_existing_draft_change_log( log records (the only place where it's normal to have the same old and new versions). """ + entity_id = get_model_id(entity_id) try: # Check to see if this PublishableEntity has already been changed in # this DraftChangeLog. If so, we update that record instead of creating @@ -1523,7 +1536,9 @@ def hash_for_log_record( return digest -def soft_delete_draft(publishable_entity_id: PublishableEntity.ID, /, deleted_by: int | None = None) -> None: +def soft_delete_draft( + publishable_entity_id: PublishableEntity | PublishableEntity.ID, /, deleted_by: int | None = None +) -> None: """ Sets the Draft version to None. @@ -1533,11 +1548,12 @@ def soft_delete_draft(publishable_entity_id: PublishableEntity.ID, /, deleted_by of pointing the Draft back to the most recent ``PublishableEntityVersion`` for a given ``PublishableEntity``. """ + publishable_entity_id = get_model_id(publishable_entity_id) return set_draft_version(publishable_entity_id, None, set_by=deleted_by) def reset_drafts_to_published( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, /, reset_at: datetime | None = None, reset_by: int | None = None, # User.id @@ -1566,6 +1582,7 @@ def reset_drafts_to_published( latest version created for a PublishableEntity (its ``latest`` attribute), rather than basing it off of the version that Draft points to. """ + learning_package_id = get_model_id(learning_package_id) if reset_at is None: reset_at = datetime.now(tz=timezone.utc) @@ -1665,7 +1682,7 @@ def filter_publishable_entities( def get_published_version_as_of( - entity_id: PublishableEntity.ID, + entity_id: PublishableEntity | PublishableEntity.ID, publish_log_id: int, ) -> PublishableEntityVersion | None: """ @@ -1675,6 +1692,7 @@ def get_published_version_as_of( This is a semi-private function, only available to other apps in the authoring package. """ + entity_id = get_model_id(entity_id) record = ( PublishLogRecord.objects.filter( entity_id=entity_id, @@ -1687,7 +1705,7 @@ def get_published_version_as_of( def bulk_draft_changes_for( - learning_package_id: LearningPackage.ID, + learning_package_id: LearningPackage | LearningPackage.ID, changed_by: int | None = None, changed_at: datetime | None = None ) -> DraftChangeLogContext: @@ -1720,6 +1738,7 @@ def bulk_draft_changes_for( with bulk_draft_changes_for(component.learning_package.id): update_one_component(component.learning_package.id, component) """ + learning_package_id = get_model_id(learning_package_id) if not changed_at: changed_at = datetime.now(tz=timezone.utc) return DraftChangeLogContext( diff --git a/src/openedx_core/__init__.py b/src/openedx_core/__init__.py index 863a0e058..a835233dc 100644 --- a/src/openedx_core/__init__.py +++ b/src/openedx_core/__init__.py @@ -6,4 +6,4 @@ """ # The version for the entire repository -__version__ = "0.46.0" +__version__ = "0.47.0" diff --git a/src/openedx_django_lib/typing.py b/src/openedx_django_lib/typing.py new file mode 100644 index 000000000..f9e5b78c6 --- /dev/null +++ b/src/openedx_django_lib/typing.py @@ -0,0 +1,25 @@ +""" +Utilities and types for working with strongly-typed Django code. +""" +import typing as t + +from django.db.models import Model + +_Model_T = t.TypeVar("_Model_T", bound=Model) +_ModelID_T = t.TypeVar("_ModelID_T", bound=int) + + +def get_model_id(model_or_id: _Model_T | _ModelID_T, /) -> _ModelID_T: + """ + Given a variable that could be a model instance or its ID, return the ID. + + Raises a TypeError if called on a model without an `.id` attribute. + Most of our models have `.id` integer PK fields, or `.id` @properties which proxy to a 1-1 model, + but some models (e.g. ManyToManys) do not. + """ + if isinstance(model_or_id, Model): + try: + return t.cast(_ModelID_T, model_or_id.id) # type: ignore[attr-defined] + except AttributeError as exc: + raise TypeError("get_model_id is only valid on models with an `id` field.") from exc + return t.cast(_ModelID_T, model_or_id) diff --git a/tests/openedx_django_lib/__init__.py b/tests/openedx_django_lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/openedx_django_lib/test_typing.py b/tests/openedx_django_lib/test_typing.py new file mode 100644 index 000000000..e6665d84e --- /dev/null +++ b/tests/openedx_django_lib/test_typing.py @@ -0,0 +1,35 @@ +""" +Tests for our django typing utils +""" +from typing import assert_type + +from tests.test_django_app.models import MyTypedModel, RelatedTypedModel + +from openedx_django_lib.typing import get_model_id + + +def test_get_model_id() -> None: + """ + Test that get_model_id behaves as expected, both at runtime and during typechecking. + """ + my_model = MyTypedModel() + related_model = RelatedTypedModel(my_model=my_model) + + # Sanity checks + assert_type(my_model, MyTypedModel) + assert_type(my_model.id, MyTypedModel.ID) + assert_type(related_model.my_model, MyTypedModel) + assert_type(related_model.my_model.id, MyTypedModel.ID) + assert_type(related_model.my_model_id, MyTypedModel.ID) + + # get_model_id on a model returns its id + assert get_model_id(my_model) == my_model.id + assert_type(get_model_id(my_model), MyTypedModel.ID) + assert get_model_id(related_model.my_model) == my_model.id + assert_type(get_model_id(related_model.my_model), MyTypedModel.ID) + + # get_model_id on an id returns itself + assert get_model_id(my_model.id) == my_model.id + assert_type(get_model_id(my_model.id), MyTypedModel.ID) + assert get_model_id(related_model.my_model.id) == my_model.id + assert_type(get_model_id(related_model.my_model.id), MyTypedModel.ID) diff --git a/tests/test_django_app/models.py b/tests/test_django_app/models.py index 203be32dd..105f16942 100644 --- a/tests/test_django_app/models.py +++ b/tests/test_django_app/models.py @@ -1,10 +1,8 @@ """ Models that are only for use in tests. - -These models are specifically for testing the `containers` API. """ -from typing import override +from typing import override, NewType from django.core.exceptions import ValidationError from django.db import models @@ -16,6 +14,34 @@ PublishableEntityMixin, PublishableEntityVersionMixin, ) +from openedx_django_lib.fields import TypedBigAutoField + + +class MyTypedModel(models.Model): + """ + A model with nothing but a typed ID field. + """ + MyTypedModelID = NewType("MyTypedModelID", int) + type ID = MyTypedModelID + + class IDField(TypedBigAutoField[ID]): + pass + + id = IDField(primary_key=True) + + +class RelatedTypedModel(models.Model): + """ + A model with nothing but a typed ID field and an FK to another typed model. + """ + MyRelatedModelID = NewType("MyRelatedModelID", int) + type ID = MyRelatedModelID + + class IDField(TypedBigAutoField[ID]): + pass + + id = IDField(primary_key=True) + my_model = models.ForeignKey(MyTypedModel, on_delete=models.CASCADE) class TestEntity(PublishableEntityMixin): From cb7742eb65f9467c712f5b77c7de1b67a443a9e0 Mon Sep 17 00:00:00 2001 From: Kyle D McCormick Date: Tue, 28 Apr 2026 16:23:32 -0400 Subject: [PATCH 2/3] test: Revert related_model assert_types These weren't passing mypy, reverting for now --- tests/openedx_django_lib/test_typing.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/openedx_django_lib/test_typing.py b/tests/openedx_django_lib/test_typing.py index e6665d84e..686a2fe4d 100644 --- a/tests/openedx_django_lib/test_typing.py +++ b/tests/openedx_django_lib/test_typing.py @@ -3,7 +3,7 @@ """ from typing import assert_type -from tests.test_django_app.models import MyTypedModel, RelatedTypedModel +from tests.test_django_app.models import MyTypedModel from openedx_django_lib.typing import get_model_id @@ -13,23 +13,15 @@ def test_get_model_id() -> None: Test that get_model_id behaves as expected, both at runtime and during typechecking. """ my_model = MyTypedModel() - related_model = RelatedTypedModel(my_model=my_model) # Sanity checks assert_type(my_model, MyTypedModel) assert_type(my_model.id, MyTypedModel.ID) - assert_type(related_model.my_model, MyTypedModel) - assert_type(related_model.my_model.id, MyTypedModel.ID) - assert_type(related_model.my_model_id, MyTypedModel.ID) # get_model_id on a model returns its id assert get_model_id(my_model) == my_model.id assert_type(get_model_id(my_model), MyTypedModel.ID) - assert get_model_id(related_model.my_model) == my_model.id - assert_type(get_model_id(related_model.my_model), MyTypedModel.ID) # get_model_id on an id returns itself assert get_model_id(my_model.id) == my_model.id assert_type(get_model_id(my_model.id), MyTypedModel.ID) - assert get_model_id(related_model.my_model.id) == my_model.id - assert_type(get_model_id(related_model.my_model.id), MyTypedModel.ID) From fd4924db3f4a4448e6716be0dabb3c4eff8bc7c2 Mon Sep 17 00:00:00 2001 From: Kyle D McCormick Date: Wed, 29 Apr 2026 00:16:47 -0400 Subject: [PATCH 3/3] fix: Get rid of get_model_id; it doesn't typecheck right --- src/openedx_content/applets/publishing/api.py | 118 +++++++++++++----- src/openedx_django_lib/typing.py | 25 ---- tests/openedx_django_lib/test_typing.py | 27 ---- tests/test_django_app/models.py | 32 +---- 4 files changed, 93 insertions(+), 109 deletions(-) delete mode 100644 src/openedx_django_lib/typing.py delete mode 100644 tests/openedx_django_lib/test_typing.py diff --git a/src/openedx_content/applets/publishing/api.py b/src/openedx_content/applets/publishing/api.py index f5589c29e..5e7e7f867 100644 --- a/src/openedx_content/applets/publishing/api.py +++ b/src/openedx_content/applets/publishing/api.py @@ -18,7 +18,6 @@ from django.db.transaction import atomic, on_commit from openedx_django_lib.fields import create_hash_digest -from openedx_django_lib.typing import get_model_id from . import signals from .contextmanagers import DraftChangeLogContext @@ -205,7 +204,10 @@ def create_publishable_entity( You'd typically want to call this right before creating your own content model that points to it. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects.create( learning_package_id=learning_package_id, entity_ref=entity_ref, @@ -231,7 +233,10 @@ def create_publishable_entity_version( You'd typically want to call this right before creating your own content version model that points to it. """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) with atomic(savepoint=False): version = PublishableEntityVersion.objects.create( entity_id=entity_id, @@ -305,7 +310,10 @@ def set_version_dependencies( will cause recalculation to the higher levels that depend on it. 3. Do not create circular dependencies. """ - version_id = get_model_id(version_id) + version_id = ( + version_id.id if isinstance(version_id, PublishableEntityVersion) + else version_id + ) PublishableEntityVersionDependency.objects.bulk_create( [ PublishableEntityVersionDependency( @@ -333,7 +341,10 @@ def get_publishable_entity_by_ref( """ Given a learning package and an entity_ref, get the matching publishable entity. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects.get( learning_package_id=learning_package_id, entity_ref=entity_ref, @@ -341,7 +352,13 @@ def get_publishable_entity_by_ref( def get_last_publish(learning_package_id: LearningPackage | LearningPackage.ID, /) -> PublishLog | None: - learning_package_id = get_model_id(learning_package_id) + """ + Get the log of the latest publish in the LearningPackage, or None if nothing has been published. + """ + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishLog.objects \ .filter(learning_package_id=learning_package_id) \ .order_by('-id') \ @@ -361,7 +378,10 @@ def get_publishable_entities( """ Get all entities in a learning package. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -383,7 +403,10 @@ def get_entities_with_unpublished_changes( By default, this excludes soft-deleted drafts but can be included using include_deleted_drafts option. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) entities_qs = ( PublishableEntity.objects .filter(learning_package_id=learning_package_id) @@ -414,7 +437,10 @@ def get_entities_with_unpublished_deletes( not-null Published version. (If both are null, it means it's already been deleted in a previous publish, or it was never published.) """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) return PublishableEntity.objects \ .filter( learning_package_id=learning_package_id, @@ -432,7 +458,10 @@ def publish_all_drafts( """ Publish everything that is a Draft and is not already published. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) draft_qset = ( Draft.objects .filter(entity__learning_package_id=learning_package_id) @@ -492,7 +521,10 @@ def publish_from_drafts( By default, this will also publish all dependencies (e.g. unpinned children) of the Drafts that are passed in. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if published_at is None: published_at = datetime.now(tz=timezone.utc) @@ -644,7 +676,10 @@ def get_entity_draft_history( soft-delete DraftChangeLogRecord (new_version=None) is included because it was made after the last real publish. """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) qs = ( DraftChangeLogRecord.objects .filter(entity_id=entity_id) @@ -704,7 +739,10 @@ def get_entity_publish_history( PublishLogRecord captures only the version that was actually published, not the intermediate draft versions. """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) return ( PublishLogRecord.objects .filter(entity_id=entity_id) @@ -747,7 +785,10 @@ def get_entity_publish_history_entries( Raises PublishLogRecord.DoesNotExist if publish_log_uuid is not found for this entity. """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) # Fetch the PublishLogRecord for the requested PublishLog pub_record = ( @@ -837,7 +878,10 @@ def get_entity_version_contributors( - A user who contributed multiple versions in the range appears only once (results are deduplicated with DISTINCT). """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) if new_version_num is not None: version_filter = Q( @@ -880,7 +924,7 @@ def get_entity_version_contributors( def set_draft_version( draft: Draft | PublishableEntity.ID, - entity_version_id: PublishableEntityVersion | int | None, + version_id: PublishableEntityVersion | int | None, /, set_at: datetime | None = None, set_by: int | None = None, # User.id @@ -910,7 +954,10 @@ def set_draft_version( """ if set_at is None: set_at = datetime.now(tz=timezone.utc) - entity_version_id = get_model_id(entity_version_id) if entity_version_id else None + version_id = ( + version_id.id if isinstance(version_id, PublishableEntityVersion) + else version_id + ) with atomic(savepoint=False): if isinstance(draft, int): @@ -924,12 +971,12 @@ def set_draft_version( # If the Draft is already pointing at this version, there's nothing to do. old_version_id = draft.version_id - if old_version_id == entity_version_id: + if old_version_id == version_id: return # The actual update of the Draft model is here. Everything after this # block is bookkeeping in our DraftChangeLog. - draft.version_id = entity_version_id + draft.version_id = version_id # Check to see if we're inside a context manager for an active # DraftChangeLog (i.e. what happens if the caller is using the public @@ -943,7 +990,7 @@ def set_draft_version( active_change_log, draft.entity_id, old_version_id=old_version_id, - new_version_id=entity_version_id, + new_version_id=version_id, ) if draft_log_record: # Normal case: a DraftChangeLogRecord was created or updated. @@ -996,7 +1043,7 @@ def set_draft_version( draft_change_log=change_log, entity_id=draft.entity_id, old_version_id=old_version_id, - new_version_id=entity_version_id, + new_version_id=version_id, ) draft.save() _create_side_effects_for_change_log(change_log) @@ -1033,7 +1080,10 @@ def _add_to_existing_draft_change_log( log records (the only place where it's normal to have the same old and new versions). """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) try: # Check to see if this PublishableEntity has already been changed in # this DraftChangeLog. If so, we update that record instead of creating @@ -1537,7 +1587,7 @@ def hash_for_log_record( def soft_delete_draft( - publishable_entity_id: PublishableEntity | PublishableEntity.ID, /, deleted_by: int | None = None + entity_id: PublishableEntity | PublishableEntity.ID, /, deleted_by: int | None = None ) -> None: """ Sets the Draft version to None. @@ -1548,8 +1598,11 @@ def soft_delete_draft( of pointing the Draft back to the most recent ``PublishableEntityVersion`` for a given ``PublishableEntity``. """ - publishable_entity_id = get_model_id(publishable_entity_id) - return set_draft_version(publishable_entity_id, None, set_by=deleted_by) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) + return set_draft_version(entity_id, None, set_by=deleted_by) def reset_drafts_to_published( @@ -1582,7 +1635,10 @@ def reset_drafts_to_published( latest version created for a PublishableEntity (its ``latest`` attribute), rather than basing it off of the version that Draft points to. """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if reset_at is None: reset_at = datetime.now(tz=timezone.utc) @@ -1692,7 +1748,10 @@ def get_published_version_as_of( This is a semi-private function, only available to other apps in the authoring package. """ - entity_id = get_model_id(entity_id) + entity_id = ( + entity_id.id if isinstance(entity_id, PublishableEntity) + else entity_id + ) record = ( PublishLogRecord.objects.filter( entity_id=entity_id, @@ -1738,7 +1797,10 @@ def bulk_draft_changes_for( with bulk_draft_changes_for(component.learning_package.id): update_one_component(component.learning_package.id, component) """ - learning_package_id = get_model_id(learning_package_id) + learning_package_id = ( + learning_package_id.id if isinstance(learning_package_id, LearningPackage) + else learning_package_id + ) if not changed_at: changed_at = datetime.now(tz=timezone.utc) return DraftChangeLogContext( diff --git a/src/openedx_django_lib/typing.py b/src/openedx_django_lib/typing.py deleted file mode 100644 index f9e5b78c6..000000000 --- a/src/openedx_django_lib/typing.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Utilities and types for working with strongly-typed Django code. -""" -import typing as t - -from django.db.models import Model - -_Model_T = t.TypeVar("_Model_T", bound=Model) -_ModelID_T = t.TypeVar("_ModelID_T", bound=int) - - -def get_model_id(model_or_id: _Model_T | _ModelID_T, /) -> _ModelID_T: - """ - Given a variable that could be a model instance or its ID, return the ID. - - Raises a TypeError if called on a model without an `.id` attribute. - Most of our models have `.id` integer PK fields, or `.id` @properties which proxy to a 1-1 model, - but some models (e.g. ManyToManys) do not. - """ - if isinstance(model_or_id, Model): - try: - return t.cast(_ModelID_T, model_or_id.id) # type: ignore[attr-defined] - except AttributeError as exc: - raise TypeError("get_model_id is only valid on models with an `id` field.") from exc - return t.cast(_ModelID_T, model_or_id) diff --git a/tests/openedx_django_lib/test_typing.py b/tests/openedx_django_lib/test_typing.py deleted file mode 100644 index 686a2fe4d..000000000 --- a/tests/openedx_django_lib/test_typing.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Tests for our django typing utils -""" -from typing import assert_type - -from tests.test_django_app.models import MyTypedModel - -from openedx_django_lib.typing import get_model_id - - -def test_get_model_id() -> None: - """ - Test that get_model_id behaves as expected, both at runtime and during typechecking. - """ - my_model = MyTypedModel() - - # Sanity checks - assert_type(my_model, MyTypedModel) - assert_type(my_model.id, MyTypedModel.ID) - - # get_model_id on a model returns its id - assert get_model_id(my_model) == my_model.id - assert_type(get_model_id(my_model), MyTypedModel.ID) - - # get_model_id on an id returns itself - assert get_model_id(my_model.id) == my_model.id - assert_type(get_model_id(my_model.id), MyTypedModel.ID) diff --git a/tests/test_django_app/models.py b/tests/test_django_app/models.py index 105f16942..203be32dd 100644 --- a/tests/test_django_app/models.py +++ b/tests/test_django_app/models.py @@ -1,8 +1,10 @@ """ Models that are only for use in tests. + +These models are specifically for testing the `containers` API. """ -from typing import override, NewType +from typing import override from django.core.exceptions import ValidationError from django.db import models @@ -14,34 +16,6 @@ PublishableEntityMixin, PublishableEntityVersionMixin, ) -from openedx_django_lib.fields import TypedBigAutoField - - -class MyTypedModel(models.Model): - """ - A model with nothing but a typed ID field. - """ - MyTypedModelID = NewType("MyTypedModelID", int) - type ID = MyTypedModelID - - class IDField(TypedBigAutoField[ID]): - pass - - id = IDField(primary_key=True) - - -class RelatedTypedModel(models.Model): - """ - A model with nothing but a typed ID field and an FK to another typed model. - """ - MyRelatedModelID = NewType("MyRelatedModelID", int) - type ID = MyRelatedModelID - - class IDField(TypedBigAutoField[ID]): - pass - - id = IDField(primary_key=True) - my_model = models.ForeignKey(MyTypedModel, on_delete=models.CASCADE) class TestEntity(PublishableEntityMixin):