From be57c83da0e5693ac857e3f2d68831d68e6b572f Mon Sep 17 00:00:00 2001 From: Richard Tibbles Date: Fri, 21 Feb 2025 07:56:33 -0800 Subject: [PATCH 1/6] Add regression test and fix bug where license_descriptions were not synced with license_id. --- .../contentcuration/tests/test_sync.py | 38 +++++++++++++++++++ contentcuration/contentcuration/utils/sync.py | 1 + 2 files changed, 39 insertions(+) diff --git a/contentcuration/contentcuration/tests/test_sync.py b/contentcuration/contentcuration/tests/test_sync.py index 3a2ba590c5..923d8ad541 100644 --- a/contentcuration/contentcuration/tests/test_sync.py +++ b/contentcuration/contentcuration/tests/test_sync.py @@ -19,6 +19,7 @@ from contentcuration.models import Channel from contentcuration.models import ContentTag from contentcuration.models import File +from contentcuration.models import License from contentcuration.tests import testdata from contentcuration.tests.base import StudioAPITestCase from contentcuration.tests.viewsets.base import generate_create_event @@ -346,6 +347,43 @@ def test_sync_channel_titles_and_descriptions(self): for key, value in labels.items(): self.assertEqual(getattr(target_child, key), value) + def test_sync_license_description(self): + """ + Test that the license description field is synced correctly + Added as a regression test, as this was previously omitted. 
+ """ + self.assertFalse(self.channel.has_changes()) + self.assertFalse(self.derivative_channel.has_changes()) + + contentnode = ( + self.channel.main_tree.get_descendants() + .exclude(kind_id=content_kinds.TOPIC) + .first() + ) + + special_permissions_license = License.objects.get(license_name="Special Permissions") + + contentnode.license = special_permissions_license + contentnode.license_description = "You cannot use this content on a Thursday" + contentnode.copyright_holder = "Thursday's child has far to go" + contentnode.save() + + sync_channel( + self.derivative_channel, + sync_titles_and_descriptions=False, + sync_resource_details=True, + sync_files=False, + sync_assessment_items=False, + ) + + target_child = self.derivative_channel.main_tree.get_descendants().get( + source_node_id=contentnode.node_id + ) + + self.assertEqual(target_child.license, special_permissions_license) + self.assertEqual(target_child.license_description, "You cannot use this content on a Thursday") + self.assertEqual(target_child.copyright_holder, "Thursday's child has far to go") + def test_sync_channel_other_metadata_labels(self): """ Test that calling sync channel will also bring in other metadata label updates. diff --git a/contentcuration/contentcuration/utils/sync.py b/contentcuration/contentcuration/utils/sync.py index 5b3664002a..1fb2dda566 100644 --- a/contentcuration/contentcuration/utils/sync.py +++ b/contentcuration/contentcuration/utils/sync.py @@ -70,6 +70,7 @@ def sync_node( if sync_resource_details: fields = [ "license_id", + "license_description", "copyright_holder", "author", "extra_fields", From 2f70517496f66ca0b960a00bdea02f01c834b27f Mon Sep 17 00:00:00 2001 From: Richard Tibbles Date: Fri, 21 Feb 2025 12:05:12 -0800 Subject: [PATCH 2/6] Reinstate source field rectification but only for license description. 
--- Makefile | 2 +- ..._rectify_source_field_migraiton_command.py | 167 ++++++++++++++++++ ...ify_incorrect_contentnode_source_fields.py | 127 +++++++++++++ 3 files changed, 295 insertions(+), 1 deletion(-) create mode 100644 contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py create mode 100644 contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py diff --git a/Makefile b/Makefile index 051053bab3..619fcee41e 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ migrate: # 4) Remove the management command from this `deploy-migrate` recipe # 5) Repeat! deploy-migrate: - echo "Nothing to do here!" + python contentcuration/manage.py rectify_incorrect_contentnode_source_fields contentnodegc: python contentcuration/manage.py garbage_collect diff --git a/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py b/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py new file mode 100644 index 0000000000..f4643a87ac --- /dev/null +++ b/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py @@ -0,0 +1,167 @@ +# DELETE THIS FILE AFTER RUNNING THE MIGRATIONSSS +import datetime +import uuid + +from django.core.management import call_command +from django.utils import timezone +from le_utils.constants import content_kinds + +from contentcuration.models import Channel +from contentcuration.models import ContentNode +from contentcuration.models import License +from contentcuration.tests import testdata +from contentcuration.tests.base import StudioAPITestCase +from contentcuration.utils.publish import publish_channel + + +class TestRectifyMigrationCommand(StudioAPITestCase): + + @classmethod + def setUpClass(cls): + super(TestRectifyMigrationCommand, cls).setUpClass() + + def tearDown(self): + super(TestRectifyMigrationCommand, self).tearDown() + + def setUp(self): + super(TestRectifyMigrationCommand, self).setUp() + 
self.original_channel = testdata.channel() + self.license_original = License.objects.get(license_name="Special Permissions") + self.license_description_original = "License to chill" + self.original_contentnode = ContentNode.objects.create( + id=uuid.uuid4().hex, + title="Original Node", + parent=self.original_channel.main_tree, + license=self.license_original, + license_description=self.license_description_original, + original_channel_id=None, + source_channel_id=None, + author="old author" + ) + self.user = testdata.user() + self.original_channel.editors.add(self.user) + self.client.force_authenticate(user=self.user) + + def create_base_channel_and_contentnode(self, source_contentnode, source_channel): + base_channel = testdata.channel() + base_channel.public = True + base_channel.save() + base_node = ContentNode.objects.create( + id=uuid.uuid4().hex, + title="base contentnode", + parent=base_channel.main_tree, + kind_id=content_kinds.VIDEO, + original_channel_id=self.original_channel.id, + original_source_node_id=self.original_contentnode.node_id, + source_channel_id=source_channel.id, + source_node_id=source_contentnode.node_id, + author="source author", + license=self.license_original, + license_description=None, + ) + return base_node, base_channel + + def create_source_channel_and_contentnode(self): + source_channel = testdata.channel() + source_channel.public = True + source_channel.save() + source_node = ContentNode.objects.create( + id=uuid.uuid4().hex, + title="base contentnode", + parent=source_channel.main_tree, + kind_id=content_kinds.VIDEO, + license=self.license_original, + license_description="No chill", + original_channel_id=self.original_channel.id, + source_channel_id=self.original_channel.id, + source_node_id=self.original_contentnode.node_id, + original_source_node_id=self.original_contentnode.node_id, + author="source author", + ) + + return source_node, source_channel + + def run_migrations(self): + 
call_command('rectify_incorrect_contentnode_source_fields', user_id=self.user.id, is_test=True) + + def test_two_node_case(self): + base_node, base_channel = self.create_base_channel_and_contentnode(self.original_contentnode, self.original_channel) + + publish_channel(self.user.id, Channel.objects.get(pk=base_channel.pk).id) + + # main_tree node still has changed=true even after the publish + for node in Channel.objects.get(pk=base_channel.pk).main_tree.get_family().filter(changed=True): + node.changed = False + # This should probably again change the changed=true but surprisingly it does not + # Meaning the changed boolean does not change for the main_tree no matter what we do + # through ContentNode model methods like save. + node.save() + + ContentNode.objects.filter(pk=base_node.pk).update( + modified=datetime.datetime(2023, 7, 5, tzinfo=timezone.utc) + ) + + self.run_migrations() + updated_base_node = ContentNode.objects.get(pk=base_node.pk) + self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) + + def test_three_node_case_implicit(self): + source_node, source_channel = self.create_source_channel_and_contentnode() + base_node, base_channel = self.create_base_channel_and_contentnode(source_node, source_channel) + source_node.aggregator = "Nami" + source_node.save() + # Implicit case + base_node.author = source_node.author + base_node.license = source_node.license + base_node.aggregator = source_node.aggregator + base_node.save() + + publish_channel(self.user.id, Channel.objects.get(pk=base_channel.pk).id) + + for node in Channel.objects.get(pk=base_channel.pk).main_tree.get_family().filter(changed=True): + node.changed = False + node.save() + + ContentNode.objects.filter(pk=base_node.pk).update( + modified=datetime.datetime(2023, 7, 5, tzinfo=timezone.utc) + ) + + 
ContentNode.objects.filter(pk=source_node.pk).update( + modified=datetime.datetime(2023, 3, 5, tzinfo=timezone.utc) + ) + + self.run_migrations() + updated_base_node = ContentNode.objects.get(pk=base_node.pk) + updated_source_node = ContentNode.objects.get(pk=source_node.pk) + self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) + self.assertEqual(updated_source_node.license_description, self.original_contentnode.license_description) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) + + def test_three_node_case_explicit(self): + source_node, source_channel = self.create_source_channel_and_contentnode() + base_node, base_channel = self.create_base_channel_and_contentnode(source_node, source_channel) + source_node.license_description = "luffy" + base_node.license_description = "zoro" + base_node.save() + source_node.save() + publish_channel(self.user.id, Channel.objects.get(pk=base_channel.pk).id) + + for node in Channel.objects.get(pk=base_channel.pk).main_tree.get_family().filter(changed=True): + node.changed = False + node.save() + + ContentNode.objects.filter(pk=base_node.pk).update( + modified=datetime.datetime(2023, 7, 5, tzinfo=timezone.utc) + ) + + ContentNode.objects.filter(pk=source_node.pk).update( + modified=datetime.datetime(2023, 3, 5, tzinfo=timezone.utc) + ) + + self.run_migrations() + updated_base_node = ContentNode.objects.get(pk=base_node.pk) + updated_source_node = ContentNode.objects.get(pk=source_node.pk) + self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) + self.assertEqual(updated_source_node.license_description, self.original_contentnode.license_description) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) diff --git 
a/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py b/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py new file mode 100644 index 0000000000..24d41b961e --- /dev/null +++ b/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py @@ -0,0 +1,127 @@ +import logging + +from django.core.management.base import BaseCommand +from django.db.models import Exists +from django.db.models import F +from django.db.models import OuterRef +from django.db.models import Value +from django.db.models.functions import Coalesce +from django_cte import With + +from contentcuration.models import Channel +from contentcuration.models import ContentNode +from contentcuration.models import User +from contentcuration.utils.publish import publish_channel + +logger = logging.getLogger(__file__) + + +class Command(BaseCommand): + + def add_arguments(self, parser): + + parser.add_argument( + '--is_test', + action='store_true', + help="Indicate if the command is running in a test environment.", + ) + + parser.add_argument( + '--user_id', + type=int, + help="User ID for the operation", + ) + + def handle(self, *args, **options): + + is_test = options['is_test'] + user_id = options['user_id'] + + if not is_test: + user_id = User.objects.get(email='channeladmin@learningequality.org').pk + + main_trees_cte = With( + ( + Channel.objects.filter( + main_tree__isnull=False + ) + .annotate(channel_id=F("id")) + .values("channel_id", "deleted", tree_id=F("main_tree__tree_id")) + ), + name="main_trees", + ) + + nodes = main_trees_cte.join( + ContentNode.objects.all(), + tree_id=main_trees_cte.col.tree_id, + ).annotate(channel_id=main_trees_cte.col.channel_id, deleted=main_trees_cte.col.deleted) + + original_source_nodes = ( + nodes.with_cte(main_trees_cte) + .filter( + node_id=OuterRef("original_source_node_id"), + ) + .exclude( + tree_id=OuterRef("tree_id"), + ) + 
.annotate( + coalesced_license_description=Coalesce("license_description", Value("")), + ) + ) + diff = ( + nodes.with_cte(main_trees_cte).filter( + deleted=False, # we dont want the channel to be deleted or else we are fixing ghost nodes + source_node_id__isnull=False, + original_source_node_id__isnull=False, + ) + ).annotate( + coalesced_license_description=Coalesce("license_description", Value("")), + ) + diff_combined = diff.annotate( + original_source_node_f_changed=Exists( + original_source_nodes.exclude( + coalesced_license_description=OuterRef("coalesced_license_description") + ) + ) + ).filter(original_source_node_f_changed=True) + + final_nodes = diff_combined.values( + "id", + "channel_id", + "original_channel_id", + "original_source_node_id", + ).order_by() + + channel_ids_to_republish = set() + + for item in final_nodes: + base_node = ContentNode.objects.get(pk=item["id"]) + + original_source_channel_id = item["original_channel_id"] + original_source_node_id = item["original_source_node_id"] + tree_id = ( + Channel.objects.filter(pk=original_source_channel_id) + .values_list("main_tree__tree_id", flat=True) + .get() + ) + original_source_node = ContentNode.objects.filter( + tree_id=tree_id, node_id=original_source_node_id + ) + + base_channel = Channel.objects.get(pk=item['channel_id']) + + to_be_republished = not (base_channel.main_tree.get_family().filter(changed=True).exists()) + + if original_source_channel_id is not None and original_source_node.exists(): + # original source node exists and its license_description doesn't match + # update the base node + if base_node.license_description != original_source_node[0].license_description: + base_node.license_description = original_source_node[0].license_description + base_node.save() + + if to_be_republished and base_channel.last_published is not None: + channel_ids_to_republish.add(base_channel.id) + + # we would republish the channel + for channel_id in channel_ids_to_republish: + 
publish_channel(user_id, channel_id) From 6a366e29c73c093850cd113fb8748a65e4e20dc1 Mon Sep 17 00:00:00 2001 From: Richard Tibbles Date: Wed, 26 Feb 2025 07:38:57 -0800 Subject: [PATCH 3/6] Don't republish automatically. --- ..._rectify_source_field_migraiton_command.py | 8 ++--- ...ify_incorrect_contentnode_source_fields.py | 35 ------------------- 2 files changed, 4 insertions(+), 39 deletions(-) diff --git a/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py b/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py index f4643a87ac..96382e25af 100644 --- a/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py +++ b/contentcuration/contentcuration/tests/test_rectify_source_field_migraiton_command.py @@ -82,7 +82,7 @@ def create_source_channel_and_contentnode(self): return source_node, source_channel def run_migrations(self): - call_command('rectify_incorrect_contentnode_source_fields', user_id=self.user.id, is_test=True) + call_command('rectify_incorrect_contentnode_source_fields') def test_two_node_case(self): base_node, base_channel = self.create_base_channel_and_contentnode(self.original_contentnode, self.original_channel) @@ -104,7 +104,7 @@ def test_two_node_case(self): self.run_migrations() updated_base_node = ContentNode.objects.get(pk=base_node.pk) self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) - self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), True) def test_three_node_case_implicit(self): source_node, source_channel = self.create_source_channel_and_contentnode() @@ -136,7 +136,7 @@ def test_three_node_case_implicit(self): updated_source_node = ContentNode.objects.get(pk=source_node.pk) 
self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) self.assertEqual(updated_source_node.license_description, self.original_contentnode.license_description) - self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), True) def test_three_node_case_explicit(self): source_node, source_channel = self.create_source_channel_and_contentnode() @@ -164,4 +164,4 @@ def test_three_node_case_explicit(self): updated_source_node = ContentNode.objects.get(pk=source_node.pk) self.assertEqual(updated_base_node.license_description, self.original_contentnode.license_description) self.assertEqual(updated_source_node.license_description, self.original_contentnode.license_description) - self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), False) + self.assertEqual(Channel.objects.get(pk=base_channel.id).main_tree.get_family().filter(changed=True).exists(), True) diff --git a/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py b/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py index 24d41b961e..c4d40af4c0 100644 --- a/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py +++ b/contentcuration/kolibri_public/management/commands/rectify_incorrect_contentnode_source_fields.py @@ -10,36 +10,14 @@ from contentcuration.models import Channel from contentcuration.models import ContentNode -from contentcuration.models import User -from contentcuration.utils.publish import publish_channel logger = logging.getLogger(__file__) class Command(BaseCommand): - def add_arguments(self, parser): - - parser.add_argument( - '--is_test', - action='store_true', - help="Indicate if the command is 
running in a test environment.", - ) - - parser.add_argument( - '--user_id', - type=int, - help="User ID for the operation", - ) - def handle(self, *args, **options): - is_test = options['is_test'] - user_id = options['user_id'] - - if not is_test: - user_id = User.objects.get(email='channeladmin@learningequality.org').pk - main_trees_cte = With( ( Channel.objects.filter( @@ -92,8 +70,6 @@ def handle(self, *args, **options): "original_source_node_id", ).order_by() - channel_ids_to_republish = set() - for item in final_nodes: base_node = ContentNode.objects.get(pk=item["id"]) @@ -108,20 +84,9 @@ def handle(self, *args, **options): tree_id=tree_id, node_id=original_source_node_id ) - base_channel = Channel.objects.get(pk=item['channel_id']) - - to_be_republished = not (base_channel.main_tree.get_family().filter(changed=True).exists()) - if original_source_channel_id is not None and original_source_node.exists(): # original source node exists and its license_description doesn't match # update the base node if base_node.license_description != original_source_node[0].license_description: base_node.license_description = original_source_node[0].license_description base_node.save() - - if to_be_republished and base_channel.last_published is not None: - channel_ids_to_republish.add(base_channel.id) - - # we would republish the channel - for channel_id in channel_ids_to_republish: - publish_channel(user_id, channel_id) From 888c4499a9e45149fbecd345140345efbe9076d8 Mon Sep 17 00:00:00 2001 From: Blaine Jester Date: Thu, 27 Feb 2025 12:11:57 -0800 Subject: [PATCH 4/6] Allow read of production Studio bucket --- .../contentcuration/production_settings.py | 2 +- .../contentcuration/sandbox_settings.py | 2 +- contentcuration/contentcuration/settings.py | 3 +- .../contentcuration/tests/test_gcs_storage.py | 114 +++++++++++++++--- .../contentcuration/utils/gcs_storage.py | 88 ++++++++++++-- .../contentcuration/utils/storage_common.py | 3 +- 6 files changed, 183 insertions(+), 29 
deletions(-) diff --git a/contentcuration/contentcuration/production_settings.py b/contentcuration/contentcuration/production_settings.py index da6a199125..1d0a7d456d 100644 --- a/contentcuration/contentcuration/production_settings.py +++ b/contentcuration/contentcuration/production_settings.py @@ -12,7 +12,7 @@ MEDIA_ROOT = base_settings.STORAGE_ROOT -DEFAULT_FILE_STORAGE = 'contentcuration.utils.gcs_storage.GoogleCloudStorage' +DEFAULT_FILE_STORAGE = 'contentcuration.utils.gcs_storage.CompositeGCS' SESSION_ENGINE = "django.contrib.sessions.backends.db" # email settings diff --git a/contentcuration/contentcuration/sandbox_settings.py b/contentcuration/contentcuration/sandbox_settings.py index 2ca766f4ef..61e00a465f 100644 --- a/contentcuration/contentcuration/sandbox_settings.py +++ b/contentcuration/contentcuration/sandbox_settings.py @@ -3,7 +3,7 @@ DEBUG = True -DEFAULT_FILE_STORAGE = "contentcuration.utils.gcs_storage.GoogleCloudStorage" +DEFAULT_FILE_STORAGE = "contentcuration.utils.gcs_storage.CompositeGCS" LANGUAGES += (("ar", gettext("Arabic")),) # noqa diff --git a/contentcuration/contentcuration/settings.py b/contentcuration/contentcuration/settings.py index 62825373b3..595ee834fb 100644 --- a/contentcuration/contentcuration/settings.py +++ b/contentcuration/contentcuration/settings.py @@ -292,8 +292,9 @@ def gettext(s): # ('en-PT', gettext('English - Pirate')), ) +PRODUCTION_SITE_ID = 1 SITE_BY_ID = { - 'master': 1, + 'master': PRODUCTION_SITE_ID, 'unstable': 3, 'hotfixes': 4, } diff --git a/contentcuration/contentcuration/tests/test_gcs_storage.py b/contentcuration/contentcuration/tests/test_gcs_storage.py index 8af75f4f8e..165877f9ac 100755 --- a/contentcuration/contentcuration/tests/test_gcs_storage.py +++ b/contentcuration/contentcuration/tests/test_gcs_storage.py @@ -1,18 +1,15 @@ -#!/usr/bin/env python -from future import standard_library -standard_library.install_aliases() from io import BytesIO -import pytest +import mock from django.core.files 
import File from django.test import TestCase +from google.cloud.storage import Bucket from google.cloud.storage import Client from google.cloud.storage.blob import Blob from mixer.main import mixer -from mock import create_autospec -from mock import patch -from contentcuration.utils.gcs_storage import GoogleCloudStorage as gcs +from contentcuration.utils.gcs_storage import CompositeGCS +from contentcuration.utils.gcs_storage import GoogleCloudStorage class GoogleCloudStorageSaveTestCase(TestCase): @@ -21,10 +18,10 @@ class GoogleCloudStorageSaveTestCase(TestCase): """ def setUp(self): - self.blob_class = create_autospec(Blob) + self.blob_class = mock.create_autospec(Blob) self.blob_obj = self.blob_class("blob", "blob") - self.mock_client = create_autospec(Client) - self.storage = gcs(client=self.mock_client()) + self.mock_client = mock.create_autospec(Client) + self.storage = GoogleCloudStorage(client=self.mock_client(), bucket_name="bucket") self.content = BytesIO(b"content") def test_calls_upload_from_file(self): @@ -73,8 +70,8 @@ def test_uploads_cache_control_private_if_content_database(self): self.storage.save(filename, self.content, blob_object=self.blob_obj) assert "private" in self.blob_obj.cache_control - @patch("contentcuration.utils.gcs_storage.BytesIO") - @patch("contentcuration.utils.gcs_storage.GoogleCloudStorage._is_file_empty", return_value=False) + @mock.patch("contentcuration.utils.gcs_storage.BytesIO") + @mock.patch("contentcuration.utils.gcs_storage.GoogleCloudStorage._is_file_empty", return_value=False) def test_gzip_if_content_database(self, bytesio_mock, file_empty_mock): """ Check that if we're uploading a gzipped content database and @@ -99,17 +96,17 @@ class RandomFileSchema: filename = str def setUp(self): - self.blob_class = create_autospec(Blob) + self.blob_class = mock.create_autospec(Blob) self.blob_obj = self.blob_class("blob", "blob") - self.mock_client = create_autospec(Client) - self.storage = gcs(client=self.mock_client()) + 
self.mock_client = mock.create_autospec(Client) + self.storage = GoogleCloudStorage(client=self.mock_client(), bucket_name="bucket") self.local_file = mixer.blend(self.RandomFileSchema) def test_raises_error_if_mode_is_not_rb(self): """ open() should raise an assertion error if passed in a mode flag that's not "rb". """ - with pytest.raises(AssertionError): + with self.assertRaises(AssertionError): self.storage.open("randfile", mode="wb") def test_calls_blob_download_to_file(self): @@ -130,3 +127,88 @@ def test_returns_django_file(self): assert isinstance(f, File) # This checks that an actual temp file was written on disk for the file.git assert f.name + + +class CompositeGCSTestCase(TestCase): + """ + Tests for the CompositeGCS class. + """ + + def setUp(self): + mock_client_cls = mock.MagicMock(spec_set=Client) + bucket_cls = mock.MagicMock(spec_set=Bucket) + self.blob_cls = mock.MagicMock(spec_set=Blob) + + self.mock_default_client = mock_client_cls(project="project") + self.mock_anon_client = mock_client_cls(project=None) + + self.mock_default_bucket = bucket_cls(self.mock_default_client, "bucket") + self.mock_default_client.get_bucket.return_value = self.mock_default_bucket + self.mock_anon_bucket = bucket_cls(self.mock_anon_client, "bucket") + self.mock_anon_client.get_bucket.return_value = self.mock_anon_bucket + + with mock.patch("contentcuration.utils.gcs_storage._create_default_client", return_value=self.mock_default_client), \ + mock.patch("contentcuration.utils.gcs_storage.Client.create_anonymous_client", return_value=self.mock_anon_client): + self.storage = CompositeGCS() + + def test_get_writeable_backend(self): + backend = self.storage._get_writeable_backend() + self.assertEqual(backend.client, self.mock_default_client) + + def test_get_writeable_backend__raises_error_if_none(self): + self.mock_default_client.project = None + with self.assertRaises(AssertionError): + self.storage._get_writeable_backend() + + def test_get_readonly_backend(self): 
+ self.mock_anon_bucket.get_blob.return_value = self.blob_cls("blob", "blob") + backend = self.storage._get_readable_backend("blob") + self.assertEqual(backend.client, self.mock_anon_client) + + def test_get_readonly_backend__raises_error_if_not_found(self): + self.mock_default_bucket.get_blob.return_value = None + self.mock_anon_bucket.get_blob.return_value = None + with self.assertRaises(FileNotFoundError): + self.storage._get_readable_backend("blob") + + def test_open(self): + self.mock_default_bucket.get_blob.return_value = self.blob_cls("blob", "blob") + f = self.storage.open("blob") + self.assertIsInstance(f, File) + self.mock_default_bucket.get_blob.assert_called_with("blob") + + @mock.patch("contentcuration.utils.gcs_storage.Blob") + def test_save(self, mock_blob): + self.storage.save("blob", BytesIO(b"content")) + blob = mock_blob.return_value + blob.upload_from_file.assert_called() + + def test_delete(self): + mock_blob = self.blob_cls("blob", "blob") + self.mock_default_bucket.get_blob.return_value = mock_blob + self.storage.delete("blob") + mock_blob.delete.assert_called() + + def test_exists(self): + self.mock_default_bucket.get_blob.return_value = self.blob_cls("blob", "blob") + self.assertTrue(self.storage.exists("blob")) + + def test_exists__returns_false_if_not_found(self): + self.mock_default_bucket.get_blob.return_value = None + self.assertFalse(self.storage.exists("blob")) + + def test_size(self): + mock_blob = self.blob_cls("blob", "blob") + self.mock_default_bucket.get_blob.return_value = mock_blob + mock_blob.size = 4 + self.assertEqual(self.storage.size("blob"), 4) + + def test_url(self): + mock_blob = self.blob_cls("blob", "blob") + self.mock_default_bucket.get_blob.return_value = mock_blob + mock_blob.public_url = "https://storage.googleapis.com/bucket/blob" + self.assertEqual(self.storage.url("blob"), "https://storage.googleapis.com/bucket/blob") + + def test_get_created_time(self): + self.mock_default_bucket.get_blob.return_value = 
self.blob_cls("blob", "blob") + self.assertEqual(self.storage.get_created_time("blob"), self.blob_cls.return_value.time_created) diff --git a/contentcuration/contentcuration/utils/gcs_storage.py b/contentcuration/contentcuration/utils/gcs_storage.py index 07d8e899f3..896af6ed11 100644 --- a/contentcuration/contentcuration/utils/gcs_storage.py +++ b/contentcuration/contentcuration/utils/gcs_storage.py @@ -18,16 +18,24 @@ MAX_RETRY_TIME = 60 # seconds -class GoogleCloudStorage(Storage): - def __init__(self, client=None): +def _create_default_client(service_account_credentials_path=settings.GCS_STORAGE_SERVICE_ACCOUNT_KEY_PATH): + if service_account_credentials_path: + return Client.from_service_account_json(service_account_credentials_path) + return Client() + - self.client = client if client else self._create_default_client() - self.bucket = self.client.get_bucket(settings.AWS_S3_BUCKET_NAME) +class GoogleCloudStorage(Storage): + def __init__(self, client, bucket_name): + self.client = client + self.bucket = self.client.get_bucket(bucket_name) - def _create_default_client(self, service_account_credentials_path=settings.GCS_STORAGE_SERVICE_ACCOUNT_KEY_PATH): - if service_account_credentials_path: - return Client.from_service_account_json(service_account_credentials_path) - return Client() + @property + def writeable(self): + """ + See `Client.create_anonymous_client()` + :return: True if the client has a project set, False otherwise. + """ + return self.client.project is not None def open(self, name, mode="rb", blob_object=None): """ @@ -79,7 +87,7 @@ def exists(self, name): :return: True if the resource with the name exists, or False otherwise. 
""" blob = self.bucket.get_blob(name) - return blob + return blob is not None def size(self, name): blob = self.bucket.get_blob(name) @@ -199,3 +207,65 @@ def _is_file_empty(fobj): byt = fobj.read(1) fobj.seek(current_location) return len(byt) == 0 + + +class CompositeGCS(Storage): + def __init__(self): + self.backends = [] + # Only add the studio-content bucket (the production bucket) if we're not in production + if settings.SITE_ID != settings.PRODUCTION_SITE_ID: + self.backends.append(GoogleCloudStorage(Client.create_anonymous_client(), "studio-content")) + self.backends.append(GoogleCloudStorage(_create_default_client(), settings.AWS_S3_BUCKET_NAME)) + + def _get_writeable_backend(self): + """ + :rtype: GoogleCloudStorage + """ + for backend in self.backends: + if backend.writeable: + return backend + raise AssertionError("No writeable backend found") + + def _get_readable_backend(self, name): + """ + :rtype: GoogleCloudStorage + """ + for backend in self.backends: + if backend.exists(name): + return backend + raise FileNotFoundError("{} not found".format(name)) + + def open(self, name, mode='rb'): + return self._get_readable_backend(name).open(name, mode) + + def save(self, name, content, max_length=None): + return self._get_writeable_backend().save(name, content, max_length=max_length) + + def delete(self, name): + self._get_writeable_backend().delete(name) + + def exists(self, name): + try: + self._get_readable_backend(name) + return True + except FileNotFoundError: + return False + + def listdir(self, path): + # This method was not implemented on GoogleCloudStorage to begin with + raise NotImplementedError("listdir is not implemented for CompositeGCS") + + def size(self, name): + return self._get_readable_backend(name).size(name) + + def url(self, name): + return self._get_readable_backend(name).url(name) + + def get_accessed_time(self, name): + return self._get_readable_backend(name).get_accessed_time(name) + + def get_created_time(self, name): + return 
self._get_readable_backend(name).get_created_time(name) + + def get_modified_time(self, name): + return self._get_readable_backend(name).get_modified_time(name) diff --git a/contentcuration/contentcuration/utils/storage_common.py b/contentcuration/contentcuration/utils/storage_common.py index b41b018511..1e393adc70 100644 --- a/contentcuration/contentcuration/utils/storage_common.py +++ b/contentcuration/contentcuration/utils/storage_common.py @@ -6,6 +6,7 @@ from django.core.files.storage import default_storage from django_s3_storage.storage import S3Storage +from .gcs_storage import CompositeGCS from .gcs_storage import GoogleCloudStorage @@ -61,7 +62,7 @@ def get_presigned_upload_url( # both storage types are having difficulties enforcing it. mimetype = determine_content_type(filepath) - if isinstance(storage, GoogleCloudStorage): + if isinstance(storage, (GoogleCloudStorage, CompositeGCS)): client = client or storage.client bucket = settings.AWS_S3_BUCKET_NAME upload_url = _get_gcs_presigned_put_url(client, bucket, filepath, md5sum_b64, lifetime_sec, mimetype=mimetype) From 894e65060798930a5fa6e3927dfcf8a72582e094 Mon Sep 17 00:00:00 2001 From: Blaine Jester Date: Thu, 27 Feb 2025 12:58:21 -0800 Subject: [PATCH 5/6] Swap order of backends --- contentcuration/contentcuration/utils/gcs_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contentcuration/contentcuration/utils/gcs_storage.py b/contentcuration/contentcuration/utils/gcs_storage.py index 896af6ed11..740921055b 100644 --- a/contentcuration/contentcuration/utils/gcs_storage.py +++ b/contentcuration/contentcuration/utils/gcs_storage.py @@ -212,10 +212,10 @@ def _is_file_empty(fobj): class CompositeGCS(Storage): def __init__(self): self.backends = [] + self.backends.append(GoogleCloudStorage(_create_default_client(), settings.AWS_S3_BUCKET_NAME)) # Only add the studio-content bucket (the production bucket) if we're not in production if settings.SITE_ID != 
settings.PRODUCTION_SITE_ID: self.backends.append(GoogleCloudStorage(Client.create_anonymous_client(), "studio-content")) - self.backends.append(GoogleCloudStorage(_create_default_client(), settings.AWS_S3_BUCKET_NAME)) def _get_writeable_backend(self): """ From ed87e46e6c1c31f1bc1392b8bcd1dd3f60f85548 Mon Sep 17 00:00:00 2001 From: Blaine Jester Date: Fri, 28 Feb 2025 07:14:34 -0800 Subject: [PATCH 6/6] Allow access to client externally --- contentcuration/contentcuration/utils/gcs_storage.py | 6 ++++++ contentcuration/contentcuration/utils/storage_common.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/contentcuration/contentcuration/utils/gcs_storage.py b/contentcuration/contentcuration/utils/gcs_storage.py index 740921055b..9ec21a3886 100644 --- a/contentcuration/contentcuration/utils/gcs_storage.py +++ b/contentcuration/contentcuration/utils/gcs_storage.py @@ -29,6 +29,9 @@ def __init__(self, client, bucket_name): self.client = client self.bucket = self.client.get_bucket(bucket_name) + def get_client(self): + return self.client + @property def writeable(self): """ @@ -235,6 +238,9 @@ def _get_readable_backend(self, name): return backend raise FileNotFoundError("{} not found".format(name)) + def get_client(self): + return self._get_writeable_backend().get_client() + def open(self, name, mode='rb'): return self._get_readable_backend(name).open(name, mode) diff --git a/contentcuration/contentcuration/utils/storage_common.py b/contentcuration/contentcuration/utils/storage_common.py index 1e393adc70..f2ba6e3188 100644 --- a/contentcuration/contentcuration/utils/storage_common.py +++ b/contentcuration/contentcuration/utils/storage_common.py @@ -63,7 +63,7 @@ def get_presigned_upload_url( mimetype = determine_content_type(filepath) if isinstance(storage, (GoogleCloudStorage, CompositeGCS)): - client = client or storage.client + client = client or storage.get_client() bucket = settings.AWS_S3_BUCKET_NAME upload_url = 
_get_gcs_presigned_put_url(client, bucket, filepath, md5sum_b64, lifetime_sec, mimetype=mimetype) elif isinstance(storage, S3Storage):