11# -*- coding: utf-8 -*-
22
3+ """This module is a CRUD interface between resource managers and the sqlalchemy ORM"""
4+
35from sqlalchemy .orm .exc import NoResultFound
46from sqlalchemy .orm .collections import InstrumentedList
57from sqlalchemy .inspection import inspect
8+ from sqlalchemy .orm import joinedload
9+ from marshmallow import class_registry
10+ from marshmallow .base import SchemaABC
611
7- from flask_rest_jsonapi . constants import DEFAULT_PAGE_SIZE
12+ from flask import current_app
813from flask_rest_jsonapi .data_layers .base import BaseDataLayer
914from flask_rest_jsonapi .exceptions import RelationNotFound , RelatedObjectNotFound , JsonApiException ,\
1015 InvalidSort , ObjectNotFound
1116from flask_rest_jsonapi .data_layers .filtering .alchemy import create_filters
12- from flask_rest_jsonapi .schema import get_relationships
17+ from flask_rest_jsonapi .schema import get_model_field , get_related_schema , get_relationships
1318
1419
1520class SqlalchemyDataLayer (BaseDataLayer ):
21+ """Sqlalchemy data layer"""
1622
1723 def __init__ (self , kwargs ):
24+ """Initialize an instance of SqlalchemyDataLayer
25+
26+ :param dict kwargs: initialization parameters of an SqlalchemyDataLayer instance
27+ """
1828 super (SqlalchemyDataLayer , self ).__init__ (kwargs )
1929
2030 if not hasattr (self , 'session' ):
@@ -34,7 +44,8 @@ def create_object(self, data, view_kwargs):
3444 self .before_create_object (data , view_kwargs )
3545
3646 relationship_fields = get_relationships (self .resource .schema )
37- obj = self .model (** {key : value for (key , value ) in data .items () if key not in relationship_fields })
47+ obj = self .model (** {get_model_field (self .resource .schema , key ): value
48+ for (key , value ) in data .items () if key not in relationship_fields })
3849 self .apply_relationships (data , obj )
3950
4051 self .session .add (obj )
@@ -93,11 +104,14 @@ def get_collection(self, qs, view_kwargs):
93104
94105 object_count = query .count ()
95106
107+ if getattr (self , 'eagerload_includes' , True ):
108+ query = self .eagerload_includes (query , qs )
109+
96110 query = self .paginate_query (query , qs .pagination )
97111
98112 collection = query .all ()
99113
100- self .after_get_collection (collection , qs , view_kwargs )
114+ collection = self .after_get_collection (collection , qs , view_kwargs )
101115
102116 return object_count , collection
103117
@@ -119,8 +133,8 @@ def update_object(self, obj, data, view_kwargs):
119133
120134 relationship_fields = get_relationships (self .resource .schema )
121135 for key , value in data .items ():
122- if hasattr (obj , key ) and key not in relationship_fields :
123- setattr (obj , key , value )
136+ if hasattr (obj , get_model_field ( self . resource . schema , key ) ) and key not in relationship_fields :
137+ setattr (obj , get_model_field ( self . resource . schema , key ) , value )
124138
125139 self .apply_relationships (data , obj )
126140
@@ -380,11 +394,13 @@ def apply_relationships(self, data, obj):
380394 :param DeclarativeMeta obj: the sqlalchemy object to plug relationships to
381395 :return boolean: True if relationship have changed else False
382396 """
397+ relationships_to_apply = []
383398 relationship_fields = get_relationships (self .resource .schema )
384399 for key , value in data .items ():
385400 if key in relationship_fields :
386- related_model = getattr (obj .__class__ , key ).property .mapper .class_
387- related_id_field = self .resource .schema ._declared_fields [relationship_fields [key ]].id_field
401+ related_model = getattr (obj .__class__ ,
402+ get_model_field (self .resource .schema , key )).property .mapper .class_
403+ related_id_field = self .resource .schema ._declared_fields [key ].id_field
388404
389405 if isinstance (value , list ):
390406 related_objects = []
@@ -393,14 +409,17 @@ def apply_relationships(self, data, obj):
393409 related_object = self .get_related_object (related_model , related_id_field , {'id' : identifier })
394410 related_objects .append (related_object )
395411
396- setattr ( obj , key , related_objects )
412+ relationships_to_apply . append ({ 'field' : key , 'value' : related_objects } )
397413 else :
398414 related_object = None
399415
400416 if value is not None :
401417 related_object = self .get_related_object (related_model , related_id_field , {'id' : value })
402418
403- setattr (obj , key , related_object )
419+ relationships_to_apply .append ({'field' : key , 'value' : related_object })
420+
421+ for relationship in relationships_to_apply :
422+ setattr (obj , get_model_field (self .resource .schema , relationship ['field' ]), relationship ['value' ])
404423
405424 def filter_query (self , query , filter_info , model ):
406425 """Filter query according to jsonapi 1.0
@@ -441,13 +460,49 @@ def paginate_query(self, query, paginate_info):
441460 if int (paginate_info .get ('size' , 1 )) == 0 :
442461 return query
443462
444- page_size = int (paginate_info .get ('size' , 0 )) or DEFAULT_PAGE_SIZE
463+ page_size = int (paginate_info .get ('size' , 0 )) or current_app . config [ 'PAGE_SIZE' ]
445464 query = query .limit (page_size )
446465 if paginate_info .get ('number' ):
447466 query = query .offset ((int (paginate_info ['number' ]) - 1 ) * page_size )
448467
449468 return query
450469
470+ def eagerload_includes (self , query , qs ):
471+ """Use eagerload feature of sqlalchemy to optimize data retrieval for include querystring parameter
472+
473+ :param Query query: sqlalchemy queryset
474+ :param QueryStringManager qs: a querystring manager to retrieve information from url
475+ :return Query: the query with includes eagerloaded
476+ """
477+ for include in qs .include :
478+ joinload_object = None
479+
480+ if '.' in include :
481+ current_schema = self .resource .schema
482+ for obj in include .split ('.' ):
483+ field = get_model_field (current_schema , obj )
484+
485+ if joinload_object is None :
486+ joinload_object = joinedload (field , innerjoin = True )
487+ else :
488+ joinload_object = joinload_object .joinedload (field , innerjoin = True )
489+
490+ related_schema_cls = get_related_schema (current_schema , obj )
491+
492+ if isinstance (related_schema_cls , SchemaABC ):
493+ related_schema_cls = related_schema_cls .__class__
494+ else :
495+ related_schema_cls = class_registry .get_class (related_schema_cls )
496+
497+ current_schema = related_schema_cls
498+ else :
499+ field = get_model_field (self .resource .schema , include )
500+ joinload_object = joinedload (field , innerjoin = True )
501+
502+ query = query .options (joinload_object )
503+
504+ return query
505+
451506 def query (self , view_kwargs ):
452507 """Construct the base query to retrieve wanted data
453508
@@ -502,7 +557,7 @@ def after_get_collection(self, collection, qs, view_kwargs):
502557 :param QueryStringManager qs: a querystring manager to retrieve information from url
503558 :param dict view_kwargs: kwargs from the resource view
504559 """
505- pass
560+ return collection
506561
507562 def before_update_object (self , obj , data , view_kwargs ):
508563 """Make checks or provide additional data before update object
0 commit comments