Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions ming/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, TypeVar, Generic

from ming.utils import classproperty
from ming.base import Object as BaseObject
import ming.schema

if TYPE_CHECKING:
Expand All @@ -15,6 +16,106 @@ class MingEncryptionError(Exception):
pass


class EncryptedObject(BaseObject):
"""A dict-like object that supports DecryptedField behavior for nested encrypted fields.

This class extends :class:`ming.base.Object` and provides automatic decryption/encryption
when accessing fields that have a corresponding encrypted counterpart.
"""

__slots__ = ('_decrypted_fields', '_encr_func', '_decr_func')

def __init__(self, data=None, decrypted_fields=None, encr_func=None, decr_func=None):
"""
:param data: Initial data for the object
:param decrypted_fields: Dict mapping decrypted field names to their DecryptedField instances
:param encr_func: Function to encrypt values (datastore.encr)
:param decr_func: Function to decrypt values (datastore.decr)
"""
super().__init__(data or {})
object.__setattr__(self, '_decrypted_fields', decrypted_fields or {})
object.__setattr__(self, '_encr_func', encr_func)
object.__setattr__(self, '_decr_func', decr_func)

def __getattr__(self, name):
# Check if this is a decrypted field
decrypted_fields = object.__getattribute__(self, '_decrypted_fields')
if name in decrypted_fields:
decr_func = object.__getattribute__(self, '_decr_func')
decrypted_field = decrypted_fields[name]
encrypted_value = self.get(decrypted_field.encrypted_field)
if decr_func is not None and encrypted_value is not None:
return decr_func(encrypted_value)
return encrypted_value

# Fall back to standard Object behavior
try:
return self[name]
except KeyError:
raise AttributeError(name)

def __setattr__(self, name, value):
# Check if this is a decrypted field
decrypted_fields = object.__getattribute__(self, '_decrypted_fields')
if name in decrypted_fields:
encr_func = object.__getattribute__(self, '_encr_func')
decrypted_field = decrypted_fields[name]

# Type check
if value is not None and not isinstance(value, decrypted_field.field_type):
raise TypeError(f'not {decrypted_field.field_type}, got {value!r}')

# Encrypt and store
if encr_func is not None and value is not None:
encrypted_value = encr_func(value)
else:
encrypted_value = value
self[decrypted_field.encrypted_field] = encrypted_value
return

# Fall back to standard Object behavior
if name in self.__class__.__dict__:
super().__setattr__(name, value)
else:
self[name] = value

def __getitem__(self, name):
# Check if this is a decrypted field accessed via dict notation
decrypted_fields = object.__getattribute__(self, '_decrypted_fields')
if name in decrypted_fields:
decr_func = object.__getattribute__(self, '_decr_func')
decrypted_field = decrypted_fields[name]
encrypted_field_name = decrypted_field.encrypted_field
if encrypted_field_name in self:
encrypted_value = dict.__getitem__(self, encrypted_field_name)
else:
encrypted_value = None
if decr_func is not None and encrypted_value is not None:
return decr_func(encrypted_value)
return encrypted_value
return dict.__getitem__(self, name)

def __setitem__(self, name, value):
# Check if this is a decrypted field accessed via dict notation
decrypted_fields = object.__getattribute__(self, '_decrypted_fields')
if name in decrypted_fields:
encr_func = object.__getattribute__(self, '_encr_func')
decrypted_field = decrypted_fields[name]

# Type check
if value is not None and not isinstance(value, decrypted_field.field_type):
raise TypeError(f'not {decrypted_field.field_type}, got {value!r}')

# Encrypt and store
if encr_func is not None and value is not None:
encrypted_value = encr_func(value)
else:
encrypted_value = value
dict.__setitem__(self, decrypted_field.encrypted_field, encrypted_value)
return
dict.__setitem__(self, name, value)


class EncryptionConfig:
"""
A class to hold the encryption configuration for a ming datastore.
Expand Down Expand Up @@ -203,11 +304,47 @@ def encrypt_some_fields(cls, data: dict) -> dict:
:param data: a dictionary of data to be encrypted
:return: a modified copy of the ``data`` param with the currently-unencrypted-but-encryptable fields replaced with ``_encrypted`` counterparts.
"""
from ming.declarative import Document
from ming.odm.declarative import MappedClass

encrypted_data = data.copy()

# Handle top-level decrypted fields
for fld in cls.decrypted_field_names():
if fld in encrypted_data:
val = encrypted_data.pop(fld)
encrypted_data[f'{fld}_encrypted'] = cls.encr(val)

# Handle nested EncryptedObject fields
if issubclass(cls, Document):
schema = cls.m.schema
elif issubclass(cls, MappedClass):
schema = cls.query.mapper.collection.m.schema
else:
return encrypted_data

# Check each field in the schema for EncryptedObjectSchema
if hasattr(schema, 'fields'):
for field_name, field_schema in schema.fields.items():
if field_name in encrypted_data and isinstance(encrypted_data[field_name], dict):
# Check if this field has an EncryptedObjectSchema
from ming.schema import EncryptedObjectSchema
if isinstance(field_schema, EncryptedObjectSchema):
# Recursively encrypt nested fields
nested_data = encrypted_data[field_name]
encrypted_nested = {}

# Copy over all existing fields
encrypted_nested.update(nested_data)

# Encrypt decrypted fields in the nested dict
for decrypted_name, decrypted_field in field_schema._decrypted_fields.items():
if decrypted_name in encrypted_nested:
val = encrypted_nested.pop(decrypted_name)
encrypted_nested[decrypted_field.encrypted_field] = cls.encr(val)

encrypted_data[field_name] = encrypted_nested

return encrypted_data

def decrypt_some_fields(self) -> dict:
Expand Down
94 changes: 91 additions & 3 deletions ming/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,15 @@ def make(cls, field, *args, **kwargs):
else:
raise ValueError('Array must have 0-1 elements')
elif isinstance(field, dict):
field = Object(field, *args, **kwargs)
# Check if the dict contains any DecryptedField instances
from ming.encryption import DecryptedField
has_decrypted_fields = any(
isinstance(v, DecryptedField) for v in field.values()
)
if has_decrypted_fields:
field = EncryptedObjectSchema(field, *args, **kwargs)
else:
field = Object(field, *args, **kwargs)
elif field is None:
field = Anything(*args, **kwargs)
elif field in SHORTHAND:
Expand Down Expand Up @@ -400,6 +408,62 @@ def extend(self, other):
self.fields.update(other.fields)


class EncryptedObjectSchema(Object):
"""Schema for dict-like objects that contain :class:`ming.encryption.DecryptedField` instances.

This schema extends :class:`Object` and provides support for nested encrypted fields.
When a dict schema contains DecryptedField instances, this schema is automatically used
instead of the regular Object schema.

Example::

profile = Field(dict(
first_name=DecryptedField(str, 'first_name_encrypted'),
first_name_encrypted=ming.schema.Binary
))

The DecryptedField values in the dict are stored but not validated as schema items.
Instead, they are used to configure the resulting EncryptedObject to handle
encryption/decryption on access.
"""

def __init__(self, fields=None, required=False, if_missing=NoDefault):
if fields is None:
fields = {}

# Separate DecryptedField instances from regular schema fields
from ming.encryption import DecryptedField
self._decrypted_fields = {}
regular_fields = {}

for name, field in fields.items():
if isinstance(field, DecryptedField):
self._decrypted_fields[name] = field
else:
regular_fields[name] = field

# Initialize parent with only regular fields
super().__init__(regular_fields, required, if_missing)

def _validate(self, d, allow_extra=False, strip_extra=False):
# First, validate using parent class
result = super()._validate(d, allow_extra=allow_extra, strip_extra=strip_extra)

# Convert to EncryptedObject if we have decrypted fields
if self._decrypted_fields:
from ming.encryption import EncryptedObject
# Get encryption functions from the context (passed through validation chain)
# These will be set by the parent Document's validation
encrypted_result = EncryptedObject(
result,
decrypted_fields=self._decrypted_fields,
encr_func=None, # Will be set by parent
decr_func=None # Will be set by parent
)
return encrypted_result
return result


class Document(Object):
"""Specializes :class:`Object` adding polymorphic validation.

Expand Down Expand Up @@ -454,12 +518,36 @@ def _validate(self, d, allow_extra=False, strip_extra=False):
cls = self.get_polymorphic_cls(d)
if cls is None or cls == self.managed_class:
result = cls.__new__(cls)
result.update(super()._validate(
d, allow_extra=allow_extra, strip_extra=strip_extra))
validated_data = super()._validate(
d, allow_extra=allow_extra, strip_extra=strip_extra)
result.update(validated_data)
# Inject encryption functions into nested EncryptedObject instances
self._inject_encryption_funcs(result, cls)
return result
return cls.m.make(
d, allow_extra=allow_extra, strip_extra=strip_extra)

def _inject_encryption_funcs(self, obj, cls):
"""Recursively inject encryption functions into nested EncryptedObject instances."""
from ming.encryption import EncryptedObject, EncryptedMixin

if not issubclass(cls, EncryptedMixin):
return

for key, value in obj.items():
if isinstance(value, EncryptedObject):
# Inject encryption functions from the document class
object.__setattr__(value, '_encr_func', cls.encr)
object.__setattr__(value, '_decr_func', cls.decr)
elif isinstance(value, dict):
# Recursively process nested dicts
self._inject_encryption_funcs(value, cls)
elif isinstance(value, list):
# Process lists of dicts
for item in value:
if isinstance(item, (dict, EncryptedObject)):
self._inject_encryption_funcs(item, cls)

def set_polymorphic(self, field, registry, identity):
"""Configure polymorphic behaviour (except for ``.managed_class``).

Expand Down
60 changes: 60 additions & 0 deletions ming/tests/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,66 @@ class __mongometa__:
self.assertEqual(doc.name, None)
self.assertEqual(doc.name_encrypted, None)

def test_nested_decrypted_field(self):
"""Test DecryptedField inside a nested dict field."""
class TestDocNested(Document):
class __mongometa__:
name='test_doc_nested'
session = ming.Session.by_name('test_db')
_id = Field(S.Anything)
profile = Field(dict(
first_name=DecryptedField(str, 'first_name_encrypted'),
first_name_encrypted=S.Binary,
last_name=DecryptedField(str, 'last_name_encrypted'),
last_name_encrypted=S.Binary,
age=int # non-encrypted field
))

# Create document using make_encr with unencrypted field names
doc = TestDocNested.make_encr(dict(
_id=1,
profile=dict(
first_name='Jerome',
last_name='Smith',
age=30
)
))
doc.m.save()

# Test reading decrypted values
self.assertEqual(doc.profile.first_name, 'Jerome')
self.assertEqual(doc.profile.last_name, 'Smith')
self.assertEqual(doc.profile.age, 30)

# Test that encrypted values are stored correctly
self.assertIsInstance(doc.profile.first_name_encrypted, bytes)
self.assertEqual(doc.profile.first_name_encrypted, TestDocNested.encr('Jerome'))

# Verify that only encrypted fields are in the underlying dict (not the virtual decrypted fields)
self.assertIn('first_name_encrypted', doc.profile)
self.assertIn('last_name_encrypted', doc.profile)
self.assertNotIn('first_name', dict.keys(doc.profile))
self.assertNotIn('last_name', dict.keys(doc.profile))

# Test setting decrypted values
doc.profile.first_name = 'Jane'
doc.m.save()
self.assertEqual(doc.profile.first_name, 'Jane')
self.assertEqual(doc.profile.first_name_encrypted, TestDocNested.encr('Jane'))

# Test using dict-style access
doc.profile['last_name'] = 'Doe'
doc.m.save()
self.assertEqual(doc.profile['last_name'], 'Doe')
self.assertEqual(doc.profile['last_name_encrypted'], TestDocNested.encr('Doe'))

# Test setting None
doc.profile.first_name = None
doc.m.save()
self.assertEqual(doc.profile.first_name, None)
self.assertEqual(doc.profile.first_name_encrypted, None)


class TestDocumentEncryptionMimAutoSettings(TestDocumentEncryption):
def setUp(self):
# replace super() NOT using it
Expand Down