Skip to content

Commit c087fe7

Browse files
authored
Merge pull request #45 from nathanjshaffer/dataclass_support
Added support for JSON serialization via dataclass module
2 parents 7146a82 + 6736447 commit c087fe7

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

sqlacodegen/codegen.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
_flask_prepend = 'db.'
3434

35+
_dataclass = False
36+
3537

3638
class _DummyInflectEngine(object):
3739
def singular_noun(self, noun):
@@ -313,7 +315,7 @@ def render(self):
313315
class ModelClass(Model):
314316
parent_name = 'Base'
315317

316-
def __init__(self, table, association_tables, inflect_engine, detect_joined):
318+
def __init__(self, table, association_tables, inflect_engine, detect_joined, collector):
317319
super(ModelClass, self).__init__(table)
318320
self.name = self._tablename_to_classname(table.name, inflect_engine)
319321
self.children = []
@@ -322,6 +324,10 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined):
322324
# Assign attribute names for columns
323325
for column in table.columns:
324326
self._add_attribute(column.name, column)
327+
if _dataclass:
328+
if column.type.python_type.__module__ != 'builtins':
329+
collector.add_literal_import(column.type.python_type.__module__, column.type.python_type.__name__)
330+
325331

326332
# Add many-to-one relationships
327333
pk_column_names = set(col.name for col in table.primary_key.columns)
@@ -368,7 +374,13 @@ def add_imports(self, collector):
368374
child.add_imports(collector)
369375

370376
def render(self):
377+
global _dataclass
378+
371379
text = 'class {0}({1}):\n'.format(self.name, self.parent_name)
380+
381+
if _dataclass:
382+
text = '@dataclass\n' + text
383+
372384
text += ' __tablename__ = {0!r}\n'.format(self.table.name)
373385

374386
# Render constraints and indexes as __table_args__
@@ -403,6 +415,9 @@ def render(self):
403415
for attr, column in self.attributes.items():
404416
if isinstance(column, Column):
405417
show_name = attr != column.name
418+
if _dataclass:
419+
text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n'
420+
406421
text += ' {0} = {1}\n'.format(attr, _render_column(column, show_name))
407422

408423
# Render relationships
@@ -536,7 +551,7 @@ class CodeGenerator(object):
536551

537552
def __init__(self, metadata, noindexes=False, noconstraints=False,
538553
nojoined=False, noinflect=False, nobackrefs=False,
539-
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False):
554+
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False):
540555
super(CodeGenerator, self).__init__()
541556

542557
if noinflect:
@@ -554,6 +569,11 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
554569
_flask_prepend = ''
555570

556571
self.nocomments = nocomments
572+
573+
self.dataclass = dataclass
574+
if self.dataclass:
575+
global _dataclass
576+
_dataclass = True
557577

558578
# Pick association tables from the metadata into their own set, don't process them normally
559579
links = defaultdict(lambda: [])
@@ -612,13 +632,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
612632

613633
# Only generate classes when notables is set to True
614634
if notables:
615-
model = ModelClass(table, links[table.name], inflect_engine, not nojoined)
635+
model = ModelClass(table, links[table.name], inflect_engine, not nojoined, self.collector)
616636
classes[model.name] = model
617637
elif not table.primary_key or table.name in association_tables or noclasses:
618638
# Only form model classes for tables that have a primary key and are not association tables
619639
model = ModelTable(table)
620640
elif not noclasses:
621-
model = ModelClass(table, links[table.name], inflect_engine, not nojoined)
641+
model = ModelClass(table, links[table.name], inflect_engine, not nojoined, self.collector)
622642
classes[model.name] = model
623643

624644
self.models.append(model)
@@ -654,8 +674,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
654674
else:
655675
self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')
656676
self.collector.add_literal_import('sqlalchemy', 'MetaData')
677+
678+
679+
if self.dataclass:
680+
self.collector.add_literal_import('dataclasses', 'dataclass')
657681

658682
def render(self, outfile=sys.stdout):
683+
659684
print(self.header, file=outfile)
660685

661686
# Render the collected imports

sqlacodegen/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def main():
4040
parser.add_argument('--flask', action='store_true', help="use Flask-SQLAlchemy columns")
4141
parser.add_argument('--ignore-cols', help="Don't check foreign key constraints on specified columns (comma-separated)")
4242
parser.add_argument('--nocomments', action='store_true', help="don't render column comments")
43+
parser.add_argument('--dataclass', action='store_true', help="add dataclass decorators for JSON serialization")
4344
args = parser.parse_args()
4445

4546
if args.version:
@@ -62,7 +63,7 @@ def main():
6263
outfile = codecs.open(args.outfile, 'w', encoding='utf-8') if args.outfile else sys.stdout
6364
generator = CodeGenerator(metadata, args.noindexes, args.noconstraints,
6465
args.nojoined, args.noinflect, args.nobackrefs,
65-
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables)
66+
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass)
6667
generator.render(outfile)
6768

6869

0 commit comments

Comments
 (0)