diff --git a/parserator/data_prep_utils.py b/parserator/data_prep_utils.py
index 6f36dd1..0d5b635 100644
--- a/parserator/data_prep_utils.py
+++ b/parserator/data_prep_utils.py
@@ -51,6 +51,15 @@ def _strip_formatting(self, xml):
 
     def __iter__(self):
         for sequence_xml in self.xml:
+            # Comments (handy for grouping examples in the same file) and
+            # other non-element nodes are not training sequences; lxml
+            # gives them a non-string ``tag``, unlike real elements.
+            if not isinstance(sequence_xml.tag, str):
+                continue
+            # A sequence that ended up with no children has nothing to
+            # label and would otherwise blow up trainModel downstream.
+            if len(sequence_xml) == 0:
+                continue
             raw_text = etree.tostring(sequence_xml, method="text", encoding="unicode")
             yield raw_text, self._xml_to_sequence(sequence_xml)
 
diff --git a/tests/test_xml.py b/tests/test_xml.py
index c8b23bd..0e4dff2 100644
--- a/tests/test_xml.py
+++ b/tests/test_xml.py
@@ -51,5 +51,38 @@ def XMLequals(self, labeled_sequence, xml):
         assert correct_xml == generated_xml
 
 
+class TestTrainingDataIter(unittest.TestCase):
+    """TrainingData.__iter__ must skip comments and empty sequences."""
+
+    # NOTE(review): the element names used in the fixtures below are
+    # placeholders; confirm them against the project's real label schema.
+
+    def _td(self, xml_str):
+        return data_prep_utils.TrainingData(xml=etree.fromstring(xml_str))
+
+    def test_skips_comments(self):
+        # Comments are useful for grouping examples in the same file; they
+        # used to leak into __iter__ and crash trainModel.
+        td = self._td(
+            "<Collection>"
+            "<!-- first group -->"
+            "<TokenSequence><foo>a</foo></TokenSequence>"
+            "<!-- second group -->"
+            "<TokenSequence><foo>b</foo></TokenSequence>"
+            "</Collection>"
+        )
+        assert [raw for raw, _ in td] == ["a", "b"]
+
+    def test_skips_empty_sequences(self):
+        td = self._td(
+            "<Collection>"
+            "<TokenSequence><foo>a</foo></TokenSequence>"
+            "<TokenSequence></TokenSequence>"
+            "<TokenSequence><foo>b</foo></TokenSequence>"
+            "</Collection>"
+        )
+        assert [raw for raw, _ in td] == ["a", "b"]
+
+
 if __name__ == "__main__":
     unittest.main()