Skip to content

Commit b9a0e0e

Browse files
fix: More work on typing support; add black and ruff
1 parent bc57439 commit b9a0e0e

File tree

12 files changed

+433
-60
lines changed

12 files changed

+433
-60
lines changed

.github/workflows/test.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,24 @@ jobs:
3232
- 27017:27017
3333
steps:
3434
- uses: actions/checkout@v3
35+
- name: Install poetry
36+
run: pipx install poetry
3537
- name: Set up Python ${{ matrix.python-version }}
3638
uses: actions/setup-python@v4
3739
with:
3840
python-version: ${{ matrix.python }}
41+
cache: 'poetry'
42+
- name: Cache virtualenv
43+
uses: actions/cache@v3
44+
with:
45+
key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('poetry.lock') }}
46+
path: .venv
3947
- name: Set up env
4048
run: |
41-
python -m pip install -U -q poetry pip
4249
poetry install -q
4350
poetry run pip install -q "${{ matrix.django }}"
4451
- name: Run tests
4552
run: |
53+
poetry run ruff .
54+
poetry run black --check .
4655
poetry run python -m pytest

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@ publish:
1212

1313
test:
1414
poetry run python -m pytest
15+
16+
codegen:
17+
python codegen.py
18+
black django_mongoengine/fields/__init__.py
19+
ruff django_mongoengine/ --fix # It doesn't work with filename.

codegen.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
def generate_fields():
2+
"""
3+
Typing support cannot handle monkey-patching at runtime, so we need to generate fields explicitly.
4+
"""
5+
from mongoengine import fields
6+
from django_mongoengine.fields import djangoflavor as mixins
7+
8+
fields_code = str(_fields)
9+
for fname in fields.__all__:
10+
mixin_name = fname if hasattr(mixins, fname) else "DjangoField"
11+
fields_code += f"class {fname}(_mixins.{mixin_name}, _fields.{fname}):\n pass\n"
12+
13+
return fields_code
14+
15+
16+
_fields = """
17+
from mongoengine import fields as _fields
18+
from . import djangoflavor as _mixins
19+
from django_mongoengine.utils.monkey import patch_mongoengine_field
20+
21+
for f in ["StringField", "ObjectIdField"]:
22+
patch_mongoengine_field(f)
23+
24+
"""
25+
26+
if __name__ == "__main__":
27+
fname = "django_mongoengine/fields/__init__.py"
28+
# This content required, because otherwise mixins import does not work.
29+
open(fname, "w").write("from mongoengine.fields import *")
30+
content = generate_fields()
31+
open(fname, "w").write(content)

django_mongoengine/document.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from functools import partial
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING
66

77
from bson.objectid import ObjectId
88
from django.db.models import Model
@@ -11,11 +11,15 @@
1111
from mongoengine import document as me
1212
from mongoengine.base import metaclasses as mtc
1313
from mongoengine.errors import FieldDoesNotExist
14+
from typing_extensions import Self
1415

1516
from .fields import ObjectIdField
1617
from .forms.document_options import DocumentMetaWrapper
1718
from .queryset import QuerySetManager
1819

20+
if TYPE_CHECKING:
21+
from mongoengine.fields import StringField
22+
1923
# TopLevelDocumentMetaclass is using ObjectIdField to create default pk field,
2024
# if one's not set explicitly.
2125
# We need to know it's not editable and auto_created.
@@ -43,11 +47,11 @@ def __new__(cls, name, bases, attrs):
4347

4448

4549
class DjangoFlavor:
46-
id: Any
47-
objects: Any = QuerySetManager()
48-
_meta: DocumentMetaWrapper
49-
_default_manager: Any = QuerySetManager()
50+
id: StringField
51+
objects = QuerySetManager[Self]()
52+
_default_manager = QuerySetManager[Self]()
5053
_get_pk_val = Model.__dict__["_get_pk_val"]
54+
_meta: DocumentMetaWrapper
5155
DoesNotExist: type[DoesNotExist]
5256

5357
def __init__(self, *args, **kwargs):
@@ -115,7 +119,7 @@ class DynamicDocument(DjangoFlavor, me.DynamicDocument):
115119
...
116120

117121
class EmbeddedDocument(DjangoFlavor, me.EmbeddedDocument):
118-
...
122+
_instance: Document
119123

120124
class DynamicEmbeddedDocument(DjangoFlavor, me.DynamicEmbeddedDocument):
121125
...
Lines changed: 171 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,182 @@
1-
from . import djangoflavor
1+
from mongoengine import fields as _fields
2+
from . import djangoflavor as _mixins
3+
from django_mongoengine.utils.monkey import patch_mongoengine_field
24

5+
for f in ["StringField", "ObjectIdField"]:
6+
patch_mongoengine_field(f)
37

4-
def init_module():
5-
"""
6-
Create classes with Django-flavor mixins,
7-
use DjangoField mixin as default
8-
"""
9-
import sys
108

11-
from mongoengine import fields
9+
class StringField(_mixins.StringField, _fields.StringField):
10+
pass
1211

13-
current_module = sys.modules[__name__]
14-
current_module.__all__ = fields.__all__
1512

16-
for name in fields.__all__:
17-
fieldcls = getattr(fields, name)
18-
mixin = getattr(djangoflavor, name, djangoflavor.DjangoField)
19-
setattr(
20-
current_module,
21-
name,
22-
type(name, (mixin, fieldcls), {}),
23-
)
13+
class URLField(_mixins.URLField, _fields.URLField):
14+
pass
2415

2516

26-
def patch_mongoengine_field(field_name):
27-
"""
28-
patch mongoengine.[field_name] for comparison support
29-
becouse it's required in django.forms.models.fields_for_model
30-
importing using mongoengine internal import cache
31-
"""
32-
from mongoengine import common
17+
class EmailField(_mixins.EmailField, _fields.EmailField):
18+
pass
3319

34-
field = common._import_class(field_name)
35-
for k in ["__eq__", "__lt__", "__hash__", "attname", "get_internal_type"]:
36-
if k not in field.__dict__:
37-
setattr(field, k, djangoflavor.DjangoField.__dict__[k])
3820

21+
class IntField(_mixins.IntField, _fields.IntField):
22+
pass
3923

40-
init_module()
4124

42-
for f in ["StringField", "ObjectIdField"]:
43-
patch_mongoengine_field(f)
25+
class LongField(_mixins.DjangoField, _fields.LongField):
26+
pass
27+
28+
29+
class FloatField(_mixins.FloatField, _fields.FloatField):
30+
pass
31+
32+
33+
class DecimalField(_mixins.DecimalField, _fields.DecimalField):
34+
pass
35+
36+
37+
class BooleanField(_mixins.BooleanField, _fields.BooleanField):
38+
pass
39+
40+
41+
class DateTimeField(_mixins.DateTimeField, _fields.DateTimeField):
42+
pass
43+
44+
45+
class DateField(_mixins.DjangoField, _fields.DateField):
46+
pass
47+
48+
49+
class ComplexDateTimeField(_mixins.DjangoField, _fields.ComplexDateTimeField):
50+
pass
51+
52+
53+
class EmbeddedDocumentField(_mixins.EmbeddedDocumentField, _fields.EmbeddedDocumentField):
54+
pass
55+
56+
57+
class ObjectIdField(_mixins.DjangoField, _fields.ObjectIdField):
58+
pass
59+
60+
61+
class GenericEmbeddedDocumentField(_mixins.DjangoField, _fields.GenericEmbeddedDocumentField):
62+
pass
63+
64+
65+
class DynamicField(_mixins.DjangoField, _fields.DynamicField):
66+
pass
67+
68+
69+
class ListField(_mixins.ListField, _fields.ListField):
70+
pass
71+
72+
73+
class SortedListField(_mixins.DjangoField, _fields.SortedListField):
74+
pass
75+
76+
77+
class EmbeddedDocumentListField(_mixins.DjangoField, _fields.EmbeddedDocumentListField):
78+
pass
79+
80+
81+
class DictField(_mixins.DictField, _fields.DictField):
82+
pass
83+
84+
85+
class MapField(_mixins.DjangoField, _fields.MapField):
86+
pass
87+
88+
89+
class ReferenceField(_mixins.ReferenceField, _fields.ReferenceField):
90+
pass
91+
92+
93+
class CachedReferenceField(_mixins.DjangoField, _fields.CachedReferenceField):
94+
pass
95+
96+
97+
class LazyReferenceField(_mixins.DjangoField, _fields.LazyReferenceField):
98+
pass
99+
100+
101+
class GenericLazyReferenceField(_mixins.DjangoField, _fields.GenericLazyReferenceField):
102+
pass
103+
104+
105+
class GenericReferenceField(_mixins.DjangoField, _fields.GenericReferenceField):
106+
pass
107+
108+
109+
class BinaryField(_mixins.DjangoField, _fields.BinaryField):
110+
pass
111+
112+
113+
class GridFSError(_mixins.DjangoField, _fields.GridFSError):
114+
pass
115+
116+
117+
class GridFSProxy(_mixins.DjangoField, _fields.GridFSProxy):
118+
pass
119+
120+
121+
class FileField(_mixins.FileField, _fields.FileField):
122+
pass
123+
124+
125+
class ImageGridFsProxy(_mixins.DjangoField, _fields.ImageGridFsProxy):
126+
pass
127+
128+
129+
class ImproperlyConfigured(_mixins.ImproperlyConfigured, _fields.ImproperlyConfigured):
130+
pass
131+
132+
133+
class ImageField(_mixins.ImageField, _fields.ImageField):
134+
pass
135+
136+
137+
class GeoPointField(_mixins.DjangoField, _fields.GeoPointField):
138+
pass
139+
140+
141+
class PointField(_mixins.DjangoField, _fields.PointField):
142+
pass
143+
144+
145+
class LineStringField(_mixins.DjangoField, _fields.LineStringField):
146+
pass
147+
148+
149+
class PolygonField(_mixins.DjangoField, _fields.PolygonField):
150+
pass
151+
152+
153+
class SequenceField(_mixins.DjangoField, _fields.SequenceField):
154+
pass
155+
156+
157+
class UUIDField(_mixins.DjangoField, _fields.UUIDField):
158+
pass
159+
160+
161+
class EnumField(_mixins.DjangoField, _fields.EnumField):
162+
pass
163+
164+
165+
class MultiPointField(_mixins.DjangoField, _fields.MultiPointField):
166+
pass
167+
168+
169+
class MultiLineStringField(_mixins.DjangoField, _fields.MultiLineStringField):
170+
pass
171+
172+
173+
class MultiPolygonField(_mixins.DjangoField, _fields.MultiPolygonField):
174+
pass
175+
176+
177+
class GeoJsonBaseField(_mixins.DjangoField, _fields.GeoJsonBaseField):
178+
pass
179+
180+
181+
class Decimal128Field(_mixins.DjangoField, _fields.Decimal128Field):
182+
pass

django_mongoengine/forms/fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(self, form, *args, **kwargs):
107107
kwargs['widget'] = EmbeddedFieldWidget(self.form.fields)
108108
kwargs['initial'] = [f.initial for f in self.form.fields.values()]
109109
kwargs['require_all_fields'] = False
110-
super().__init__(fields=tuple([f for f in self.form.fields.values()]), *args, **kwargs)
110+
super().__init__(fields=tuple(self.form.fields.values()), *args, **kwargs)
111111

112112
def bound_data(self, data, initial):
113113
return data

0 commit comments

Comments
 (0)