diff --git a/ming/encryption.py b/ming/encryption.py index aec597f..f47e511 100644 --- a/ming/encryption.py +++ b/ming/encryption.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, TypeVar, Generic from ming.utils import classproperty +from ming.base import Object as BaseObject import ming.schema if TYPE_CHECKING: @@ -15,6 +16,106 @@ class MingEncryptionError(Exception): pass +class EncryptedObject(BaseObject): + """A dict-like object that supports DecryptedField behavior for nested encrypted fields. + + This class extends :class:`ming.base.Object` and provides automatic decryption/encryption + when accessing fields that have a corresponding encrypted counterpart. + """ + + __slots__ = ('_decrypted_fields', '_encr_func', '_decr_func') + + def __init__(self, data=None, decrypted_fields=None, encr_func=None, decr_func=None): + """ + :param data: Initial data for the object + :param decrypted_fields: Dict mapping decrypted field names to their DecryptedField instances + :param encr_func: Function to encrypt values (datastore.encr) + :param decr_func: Function to decrypt values (datastore.decr) + """ + super().__init__(data or {}) + object.__setattr__(self, '_decrypted_fields', decrypted_fields or {}) + object.__setattr__(self, '_encr_func', encr_func) + object.__setattr__(self, '_decr_func', decr_func) + + def __getattr__(self, name): + # Check if this is a decrypted field + decrypted_fields = object.__getattribute__(self, '_decrypted_fields') + if name in decrypted_fields: + decr_func = object.__getattribute__(self, '_decr_func') + decrypted_field = decrypted_fields[name] + encrypted_value = self.get(decrypted_field.encrypted_field) + if decr_func is not None and encrypted_value is not None: + return decr_func(encrypted_value) + return encrypted_value + + # Fall back to standard Object behavior + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name, value): + # Check if this is a decrypted field + decrypted_fields = object.__getattribute__(self, '_decrypted_fields') + if name in decrypted_fields: + encr_func = object.__getattribute__(self, '_encr_func') + decrypted_field = decrypted_fields[name] + + # Type check + if value is not None and not isinstance(value, decrypted_field.field_type): + raise TypeError(f'not {decrypted_field.field_type}, got {value!r}') + + # Encrypt and store + if encr_func is not None and value is not None: + encrypted_value = encr_func(value) + else: + encrypted_value = value + self[decrypted_field.encrypted_field] = encrypted_value + return + + # Fall back to standard Object behavior + if name in self.__class__.__dict__: + super().__setattr__(name, value) + else: + self[name] = value + + def __getitem__(self, name): + # Check if this is a decrypted field accessed via dict notation + decrypted_fields = object.__getattribute__(self, '_decrypted_fields') + if name in decrypted_fields: + decr_func = object.__getattribute__(self, '_decr_func') + decrypted_field = decrypted_fields[name] + encrypted_field_name = decrypted_field.encrypted_field + if encrypted_field_name in self: + encrypted_value = dict.__getitem__(self, encrypted_field_name) + else: + encrypted_value = None + if decr_func is not None and encrypted_value is not None: + return decr_func(encrypted_value) + return encrypted_value + return dict.__getitem__(self, name) + + def __setitem__(self, name, value): + # Check if this is a decrypted field accessed via dict notation + decrypted_fields = object.__getattribute__(self, '_decrypted_fields') + if name in decrypted_fields: + encr_func = object.__getattribute__(self, '_encr_func') + decrypted_field = decrypted_fields[name] + + # Type check + if value is not None and not isinstance(value, decrypted_field.field_type): + raise TypeError(f'not {decrypted_field.field_type}, got {value!r}') + + # Encrypt and store + if encr_func is not None and value is not None: + encrypted_value = encr_func(value) + else: + encrypted_value = value + dict.__setitem__(self, decrypted_field.encrypted_field, encrypted_value) + return + dict.__setitem__(self, name, value) + + class EncryptionConfig: """ A class to hold the encryption configuration for a ming datastore. @@ -203,11 +304,47 @@ def encrypt_some_fields(cls, data: dict) -> dict: :param data: a dictionary of data to be encrypted :return: a modified copy of the ``data`` param with the currently-unencrypted-but-encryptable fields replaced with ``_encrypted`` counterparts. """ + from ming.declarative import Document + from ming.odm.declarative import MappedClass + encrypted_data = data.copy() + + # Handle top-level decrypted fields for fld in cls.decrypted_field_names(): if fld in encrypted_data: val = encrypted_data.pop(fld) encrypted_data[f'{fld}_encrypted'] = cls.encr(val) + + # Handle nested EncryptedObject fields + if issubclass(cls, Document): + schema = cls.m.schema + elif issubclass(cls, MappedClass): + schema = cls.query.mapper.collection.m.schema + else: + return encrypted_data + + # Check each field in the schema for EncryptedObjectSchema + if hasattr(schema, 'fields'): + for field_name, field_schema in schema.fields.items(): + if field_name in encrypted_data and isinstance(encrypted_data[field_name], dict): + # Check if this field has an EncryptedObjectSchema + from ming.schema import EncryptedObjectSchema + if isinstance(field_schema, EncryptedObjectSchema): + # Recursively encrypt nested fields + nested_data = encrypted_data[field_name] + encrypted_nested = {} + + # Copy over all existing fields + encrypted_nested.update(nested_data) + + # Encrypt decrypted fields in the nested dict + for decrypted_name, decrypted_field in field_schema._decrypted_fields.items(): + if decrypted_name in encrypted_nested: + val = encrypted_nested.pop(decrypted_name) + encrypted_nested[decrypted_field.encrypted_field] = cls.encr(val) + + encrypted_data[field_name] = encrypted_nested + return encrypted_data def decrypt_some_fields(self) -> dict: diff --git a/ming/schema.py b/ming/schema.py index d2136a8..2788165 100644 --- a/ming/schema.py +++ b/ming/schema.py @@ -118,7 +118,15 @@ def make(cls, field, *args, **kwargs): else: raise ValueError('Array must have 0-1 elements') elif isinstance(field, dict): - field = Object(field, *args, **kwargs) + # Check if the dict contains any DecryptedField instances + from ming.encryption import DecryptedField + has_decrypted_fields = any( + isinstance(v, DecryptedField) for v in field.values() + ) + if has_decrypted_fields: + field = EncryptedObjectSchema(field, *args, **kwargs) + else: + field = Object(field, *args, **kwargs) elif field is None: field = Anything(*args, **kwargs) elif field in SHORTHAND: @@ -400,6 +408,62 @@ def extend(self, other): self.fields.update(other.fields) +class EncryptedObjectSchema(Object): + """Schema for dict-like objects that contain :class:`ming.encryption.DecryptedField` instances. + + This schema extends :class:`Object` and provides support for nested encrypted fields. + When a dict schema contains DecryptedField instances, this schema is automatically used + instead of the regular Object schema. + + Example:: + + profile = Field(dict( + first_name=DecryptedField(str, 'first_name_encrypted'), + first_name_encrypted=ming.schema.Binary + )) + + The DecryptedField values in the dict are stored but not validated as schema items. + Instead, they are used to configure the resulting EncryptedObject to handle + encryption/decryption on access. + """ + + def __init__(self, fields=None, required=False, if_missing=NoDefault): + if fields is None: + fields = {} + + # Separate DecryptedField instances from regular schema fields + from ming.encryption import DecryptedField + self._decrypted_fields = {} + regular_fields = {} + + for name, field in fields.items(): + if isinstance(field, DecryptedField): + self._decrypted_fields[name] = field + else: + regular_fields[name] = field + + # Initialize parent with only regular fields + super().__init__(regular_fields, required, if_missing) + + def _validate(self, d, allow_extra=False, strip_extra=False): + # First, validate using parent class + result = super()._validate(d, allow_extra=allow_extra, strip_extra=strip_extra) + + # Convert to EncryptedObject if we have decrypted fields + if self._decrypted_fields: + from ming.encryption import EncryptedObject + # Get encryption functions from the context (passed through validation chain) + # These will be set by the parent Document's validation + encrypted_result = EncryptedObject( + result, + decrypted_fields=self._decrypted_fields, + encr_func=None, # Will be set by parent + decr_func=None # Will be set by parent + ) + return encrypted_result + return result + + class Document(Object): """Specializes :class:`Object` adding polymorphic validation. @@ -454,12 +518,36 @@ def _validate(self, d, allow_extra=False, strip_extra=False): cls = self.get_polymorphic_cls(d) if cls is None or cls == self.managed_class: result = cls.__new__(cls) - result.update(super()._validate( - d, allow_extra=allow_extra, strip_extra=strip_extra)) + validated_data = super()._validate( + d, allow_extra=allow_extra, strip_extra=strip_extra) + result.update(validated_data) + # Inject encryption functions into nested EncryptedObject instances + self._inject_encryption_funcs(result, cls) return result return cls.m.make( d, allow_extra=allow_extra, strip_extra=strip_extra) + def _inject_encryption_funcs(self, obj, cls): + """Recursively inject encryption functions into nested EncryptedObject instances.""" + from ming.encryption import EncryptedObject, EncryptedMixin + + if not issubclass(cls, EncryptedMixin): + return + + for key, value in obj.items(): + if isinstance(value, EncryptedObject): + # Inject encryption functions from the document class + object.__setattr__(value, '_encr_func', cls.encr) + object.__setattr__(value, '_decr_func', cls.decr) + elif isinstance(value, dict): + # Recursively process nested dicts + self._inject_encryption_funcs(value, cls) + elif isinstance(value, list): + # Process lists of dicts + for item in value: + if isinstance(item, (dict, EncryptedObject)): + self._inject_encryption_funcs(item, cls) + def set_polymorphic(self, field, registry, identity): """Configure polymorphic behaviour (except for ``.managed_class``). diff --git a/ming/tests/test_encryption.py b/ming/tests/test_encryption.py index 2ab6e7b..58eb33f 100644 --- a/ming/tests/test_encryption.py +++ b/ming/tests/test_encryption.py @@ -281,6 +281,66 @@ class __mongometa__: self.assertEqual(doc.name, None) self.assertEqual(doc.name_encrypted, None) + def test_nested_decrypted_field(self): + """Test DecryptedField inside a nested dict field.""" + class TestDocNested(Document): + class __mongometa__: + name='test_doc_nested' + session = ming.Session.by_name('test_db') + _id = Field(S.Anything) + profile = Field(dict( + first_name=DecryptedField(str, 'first_name_encrypted'), + first_name_encrypted=S.Binary, + last_name=DecryptedField(str, 'last_name_encrypted'), + last_name_encrypted=S.Binary, + age=int # non-encrypted field + )) + + # Create document using make_encr with unencrypted field names + doc = TestDocNested.make_encr(dict( + _id=1, + profile=dict( + first_name='Jerome', + last_name='Smith', + age=30 + ) + )) + doc.m.save() + + # Test reading decrypted values + self.assertEqual(doc.profile.first_name, 'Jerome') + self.assertEqual(doc.profile.last_name, 'Smith') + self.assertEqual(doc.profile.age, 30) + + # Test that encrypted values are stored correctly + self.assertIsInstance(doc.profile.first_name_encrypted, bytes) + self.assertEqual(doc.profile.first_name_encrypted, TestDocNested.encr('Jerome')) + + # Verify that only encrypted fields are in the underlying dict (not the virtual decrypted fields) + self.assertIn('first_name_encrypted', doc.profile) + self.assertIn('last_name_encrypted', doc.profile) + self.assertNotIn('first_name', dict.keys(doc.profile)) + self.assertNotIn('last_name', dict.keys(doc.profile)) + + # Test setting decrypted values + doc.profile.first_name = 'Jane' + doc.m.save() + self.assertEqual(doc.profile.first_name, 'Jane') + self.assertEqual(doc.profile.first_name_encrypted, TestDocNested.encr('Jane')) + + # Test using dict-style access + doc.profile['last_name'] = 'Doe' + doc.m.save() + self.assertEqual(doc.profile['last_name'], 'Doe') + self.assertEqual(doc.profile['last_name_encrypted'], TestDocNested.encr('Doe')) + + # Test setting None + doc.profile.first_name = None + doc.m.save() + self.assertEqual(doc.profile.first_name, None) + self.assertEqual(doc.profile.first_name_encrypted, None) + + class TestDocumentEncryptionMimAutoSettings(TestDocumentEncryption): def setUp(self): # replace super() NOT using it