@@ -487,7 +487,7 @@ def visitField(self, sum):
487487
488488class Obj2ModPrototypeVisitor (PickleVisitor ):
489489 def visitProduct (self , prod , name ):
490- code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);"
490+ code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, const char* field, PyArena* arena);"
491491 self .emit (code % (name , get_c_type (name )), 0 )
492492
493493 visitSum = visitProduct
@@ -511,7 +511,7 @@ def recursive_call(self, node, level):
511511 def funcHeader (self , name ):
512512 ctype = get_c_type (name )
513513 self .emit ("int" , 0 )
514- self .emit ("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name , ctype ), 0 )
514+ self .emit ("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, const char* field, PyArena* arena)" % (name , ctype ), 0 )
515515 self .emit ("{" , 0 )
516516 self .emit ("int isinstance;" , 1 )
517517 self .emit ("" , 0 )
@@ -547,6 +547,18 @@ def simpleSum(self, sum, name):
547547 def buildArgs (self , fields ):
548548 return ", " .join (fields + ["arena" ])
549549
550+ def typeCheck (self , name ):
551+ self .emit ("tp = state->%s_type;" % name , 1 )
552+ self .emit ("isinstance = PyObject_IsInstance(obj, tp);" , 1 )
553+ self .emit ("if (isinstance == -1) {" , 1 )
554+ self .emit ("return 1;" , 2 )
555+ self .emit ("}" , 1 )
556+ self .emit ("if (!isinstance && field != NULL) {" , 1 )
557+ error = "field '%%s' was expecting node of type '%s', got '%%s'" % name
558+ self .emit ("PyErr_Format(PyExc_TypeError, \" %s\" , field, _PyType_Name(Py_TYPE(obj)));" % error , 2 , reflow = False )
559+ self .emit ("return 1;" , 2 )
560+ self .emit ("}" , 1 )
561+
550562 def complexSum (self , sum , name ):
551563 self .funcHeader (name )
552564 self .emit ("PyObject *tmp = NULL;" , 1 )
@@ -559,6 +571,7 @@ def complexSum(self, sum, name):
559571 self .emit ("*out = NULL;" , 2 )
560572 self .emit ("return 0;" , 2 )
561573 self .emit ("}" , 1 )
574+ self .typeCheck (name )
562575 for a in sum .attributes :
563576 self .visitField (a , name , sum = sum , depth = 1 )
564577 for t in sum .types :
@@ -593,7 +606,7 @@ def visitSum(self, sum, name):
593606 def visitProduct (self , prod , name ):
594607 ctype = get_c_type (name )
595608 self .emit ("int" , 0 )
596- self .emit ("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name , ctype ), 0 )
609+ self .emit ("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, const char* field, PyArena* arena)" % (name , ctype ), 0 )
597610 self .emit ("{" , 0 )
598611 self .emit ("PyObject* tmp = NULL;" , 1 )
599612 for f in prod .fields :
@@ -694,8 +707,8 @@ def visitField(self, field, name, sum=None, prod=None, depth=0):
694707 self .emit ("%s val;" % ctype , depth + 2 )
695708 self .emit ("PyObject *tmp2 = Py_NewRef(PyList_GET_ITEM(tmp, i));" , depth + 2 )
696709 with self .recursive_call (name , depth + 2 ):
697- self .emit ("res = obj2ast_%s(state, tmp2, &val, arena);" %
698- field .type , depth + 2 , reflow = False )
710+ self .emit ("res = obj2ast_%s(state, tmp2, &val, \" %s \" , arena);" %
711+ ( field .type , field . name ) , depth + 2 , reflow = False )
699712 self .emit ("Py_DECREF(tmp2);" , depth + 2 )
700713 self .emit ("if (res != 0) goto failed;" , depth + 2 )
701714 self .emit ("if (len != PyList_GET_SIZE(tmp)) {" , depth + 2 )
@@ -709,8 +722,8 @@ def visitField(self, field, name, sum=None, prod=None, depth=0):
709722 self .emit ("}" , depth + 1 )
710723 else :
711724 with self .recursive_call (name , depth + 1 ):
712- self .emit ("res = obj2ast_%s(state, tmp, &%s, arena);" %
713- (field .type , field .name ), depth + 1 )
725+ self .emit ("res = obj2ast_%s(state, tmp, &%s, \" %s \" , arena);" %
726+ (field .type , field .name , field . name ), depth + 1 )
714727 self .emit ("if (res != 0) goto failed;" , depth + 1 )
715728
716729 self .emit ("Py_CLEAR(tmp);" , depth + 1 )
@@ -1701,7 +1714,9 @@ def visitModule(self, mod):
17011714
17021715/* Conversion Python -> AST */
17031716
1704- static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
1717+ static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj,
1718+ PyObject** out,
1719+ const char* Py_UNUSED(field), PyArena* arena)
17051720{
17061721 if (obj == Py_None)
17071722 obj = NULL;
@@ -1718,7 +1733,9 @@ def visitModule(self, mod):
17181733 return 0;
17191734}
17201735
1721- static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
1736+ static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj,
1737+ PyObject** out,
1738+ const char* Py_UNUSED(field), PyArena* arena)
17221739{
17231740 if (_PyArena_AddPyObject(arena, obj) < 0) {
17241741 *out = NULL;
@@ -1728,29 +1745,29 @@ def visitModule(self, mod):
17281745 return 0;
17291746}
17301747
1731- static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
1748+ static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, const char* field, PyArena* arena)
17321749{
17331750 if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
1734- PyErr_SetString (PyExc_TypeError, "AST identifier must be of type str" );
1751+ PyErr_Format (PyExc_TypeError, "field '%s' was expecting a string object", field );
17351752 return -1;
17361753 }
1737- return obj2ast_object(state, obj, out, arena);
1754+ return obj2ast_object(state, obj, out, field, arena);
17381755}
17391756
1740- static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
1757+ static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, const char* field, PyArena* arena)
17411758{
17421759 if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
1743- PyErr_SetString (PyExc_TypeError, "AST string must be of type str" );
1760+ PyErr_Format (PyExc_TypeError, "field '%s' was expecting a string or bytes object", field );
17441761 return -1;
17451762 }
1746- return obj2ast_object(state, obj, out, arena);
1763+ return obj2ast_object(state, obj, out, field, arena);
17471764}
17481765
1749- static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
1766+ static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, const char* field, PyArena* arena)
17501767{
17511768 int i;
17521769 if (!PyLong_Check(obj)) {
1753- PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
1770+ PyErr_Format(PyExc_ValueError, "field \\ "%s \\ " got an invalid integer value: %R", field , obj);
17541771 return -1;
17551772 }
17561773
@@ -2150,7 +2167,7 @@ class PartingShots(StaticVisitor):
21502167 }
21512168
21522169 mod_ty res = NULL;
2153- if (obj2ast_mod(state, ast, &res, arena) != 0)
2170+ if (obj2ast_mod(state, ast, &res, NULL, arena) != 0)
21542171 return NULL;
21552172 else
21562173 return res;
0 commit comments