From e06dbfefef0ef89003963372fed65b2568638e6f Mon Sep 17 00:00:00 2001 From: Charlie Tonneslan Date: Sun, 17 May 2026 13:34:15 -0400 Subject: [PATCH] Skip XML comments and empty sequences in TrainingData iteration The XML parser used to feed comments and empty nodes straight into trainModel, which then crashed when it tried to iterate over the tokens. Skip both during __iter__ so callers can keep using comments for grouping examples (and unintentional empty sequences don't take the whole training run down with them). Closes #54. Signed-off-by: Charlie Tonneslan --- parserator/data_prep_utils.py | 7 +++++++ tests/test_xml.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/parserator/data_prep_utils.py b/parserator/data_prep_utils.py index 6f36dd1..0d5b635 100644 --- a/parserator/data_prep_utils.py +++ b/parserator/data_prep_utils.py @@ -51,6 +51,13 @@ def _strip_formatting(self, xml): def __iter__(self): for sequence_xml in self.xml: + # Skip XML comments (handy for grouping examples in the same + # file) and sequences that ended up with no children, both of + # which would otherwise blow up trainModel downstream. + if isinstance(sequence_xml, etree._Comment): + continue + if len(sequence_xml) == 0: + continue raw_text = etree.tostring(sequence_xml, method="text", encoding="unicode") yield raw_text, self._xml_to_sequence(sequence_xml) diff --git a/tests/test_xml.py b/tests/test_xml.py index c8b23bd..0e4dff2 100644 --- a/tests/test_xml.py +++ b/tests/test_xml.py @@ -51,5 +51,33 @@ def XMLequals(self, labeled_sequence, xml): assert correct_xml == generated_xml +class TestTrainingDataIter(unittest.TestCase): + def _td(self, xml_str): + return data_prep_utils.TrainingData(xml=etree.fromstring(xml_str)) + + def test_skips_comments(self): + # Comments are useful for grouping examples in the same file; they + # used to leak into __iter__ and crash trainModel. + td = self._td( + "" + "" + "a" + "" + "b" + "" + ) + assert [raw for raw, _ in td] == ["a", "b"] + + def test_skips_empty_sequences(self): + td = self._td( + "" + "a" + "" + "b" + "" + ) + assert [raw for raw, _ in td] == ["a", "b"] + + if __name__ == "__main__": unittest.main()