3232
3333_flask_prepend = 'db.'
3434
35+ _dataclass = False
36+
3537
3638class _DummyInflectEngine (object ):
3739 def singular_noun (self , noun ):
@@ -313,7 +315,7 @@ def render(self):
313315class ModelClass (Model ):
314316 parent_name = 'Base'
315317
316- def __init__ (self , table , association_tables , inflect_engine , detect_joined ):
318+ def __init__ (self , table , association_tables , inflect_engine , detect_joined , collector ):
317319 super (ModelClass , self ).__init__ (table )
318320 self .name = self ._tablename_to_classname (table .name , inflect_engine )
319321 self .children = []
@@ -322,6 +324,10 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined):
322324 # Assign attribute names for columns
323325 for column in table .columns :
324326 self ._add_attribute (column .name , column )
327+ if _dataclass :
328+ if column .type .python_type .__module__ != 'builtins' :
329+ collector .add_literal_import (column .type .python_type .__module__ , column .type .python_type .__name__ )
330+
325331
326332 # Add many-to-one relationships
327333 pk_column_names = set (col .name for col in table .primary_key .columns )
@@ -368,7 +374,13 @@ def add_imports(self, collector):
368374 child .add_imports (collector )
369375
370376 def render (self ):
377+ global _dataclass
378+
371379 text = 'class {0}({1}):\n ' .format (self .name , self .parent_name )
380+
381+ if _dataclass :
382+ text = '@dataclass\n ' + text
383+
372384 text += ' __tablename__ = {0!r}\n ' .format (self .table .name )
373385
374386 # Render constraints and indexes as __table_args__
@@ -403,6 +415,9 @@ def render(self):
403415 for attr , column in self .attributes .items ():
404416 if isinstance (column , Column ):
405417 show_name = attr != column .name
418+ if _dataclass :
419+ text += ' ' + attr + ' : ' + column .type .python_type .__name__ + '\n '
420+
406421 text += ' {0} = {1}\n ' .format (attr , _render_column (column , show_name ))
407422
408423 # Render relationships
@@ -536,7 +551,7 @@ class CodeGenerator(object):
536551
537552 def __init__ (self , metadata , noindexes = False , noconstraints = False ,
538553 nojoined = False , noinflect = False , nobackrefs = False ,
539- flask = False , ignore_cols = None , noclasses = False , nocomments = False , notables = False ):
554+ flask = False , ignore_cols = None , noclasses = False , nocomments = False , notables = False , dataclass = False ):
540555 super (CodeGenerator , self ).__init__ ()
541556
542557 if noinflect :
@@ -554,6 +569,11 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
554569 _flask_prepend = ''
555570
556571 self .nocomments = nocomments
572+
573+ self .dataclass = dataclass
574+ if self .dataclass :
575+ global _dataclass
576+ _dataclass = True
557577
558578 # Pick association tables from the metadata into their own set, don't process them normally
559579 links = defaultdict (lambda : [])
@@ -612,13 +632,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
612632
613633 # Only generate classes when notables is set to True
614634 if notables :
615- model = ModelClass (table , links [table .name ], inflect_engine , not nojoined )
635+ model = ModelClass (table , links [table .name ], inflect_engine , not nojoined , self . collector )
616636 classes [model .name ] = model
617637 elif not table .primary_key or table .name in association_tables or noclasses :
618638 # Only form model classes for tables that have a primary key and are not association tables
619639 model = ModelTable (table )
620640 elif not noclasses :
621- model = ModelClass (table , links [table .name ], inflect_engine , not nojoined )
641+ model = ModelClass (table , links [table .name ], inflect_engine , not nojoined , self . collector )
622642 classes [model .name ] = model
623643
624644 self .models .append (model )
@@ -654,8 +674,13 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
654674 else :
655675 self .collector .add_literal_import ('sqlalchemy.ext.declarative' , 'declarative_base' )
656676 self .collector .add_literal_import ('sqlalchemy' , 'MetaData' )
677+
678+
679+ if self .dataclass :
680+ self .collector .add_literal_import ('dataclasses' , 'dataclass' )
657681
658682 def render (self , outfile = sys .stdout ):
683+
659684 print (self .header , file = outfile )
660685
661686 # Render the collected imports
0 commit comments