|
4 | 4 | import json |
5 | 5 | from typing import Optional, Any |
6 | 6 | from collections import Counter |
| 7 | +from contextlib import contextmanager |
7 | 8 |
|
8 | 9 | from medcat import cat |
9 | 10 | from medcat.data.model_card import ModelCard |
|
18 | 19 | from medcat.components.addons.meta_cat import MetaCATAddon |
19 | 20 | from medcat.utils.defaults import AVOID_LEGACY_CONVERSION_ENVIRON |
20 | 21 | from medcat.utils.defaults import LegacyConversionDisabledError |
| 22 | +from medcat.utils.config_utils import temp_changed_config |
21 | 23 |
|
22 | 24 | import unittest |
23 | 25 | import tempfile |
@@ -222,6 +224,60 @@ def test_inference_works(self): |
222 | 224 | with self.subTest(f"{nr}"): |
223 | 225 | ConvertedFunctionalityTests.assert_has_ent(ent) |
224 | 226 |
|
| 227 | + @classmethod |
| 228 | + @contextmanager |
| 229 | + def _faster_spacy_inference(cls): |
| 230 | + with temp_changed_config( |
| 231 | + cls.model.config.general.nlp, |
| 232 | + "faster_spacy_tokenization", |
| 233 | + True |
| 234 | + ): |
| 235 | + with temp_changed_config( |
| 236 | + cls.model.config.general.nlp, |
| 237 | + "modelname", |
| 238 | + "en_core_web_md" |
| 239 | + ): |
| 240 | + cls.model._recreate_pipe() |
| 241 | + yield |
| 242 | + cls.model._recreate_pipe() |
| 243 | + |
| 244 | + def _is_spacy_model(self): |
| 245 | + if self.model.config.general.nlp.provider != "spacy": |
| 246 | + raise unittest.SkipTest("Only applicable for spacy models") |
| 247 | + |
| 248 | + def test_default_spacy_runs_pipe(self): |
| 249 | + self._is_spacy_model() |
| 250 | + self.assertFalse(self.model.pipe._tokenizer._avoid_pipe) |
| 251 | + |
| 252 | + def test_faster_spacy_inference_is_set(self): |
| 253 | + self._is_spacy_model() |
| 254 | + with self._faster_spacy_inference(): |
| 255 | + self.assertTrue(self.model.pipe._tokenizer._avoid_pipe) |
| 256 | + |
| 257 | + def test_faster_spacy_inference_works(self): |
| 258 | + self._is_spacy_model() |
| 259 | + with self._faster_spacy_inference(): |
| 260 | + ents = self.model.get_entities( |
| 261 | + ConvertedFunctionalityTests.TEXT)['entities'] |
| 262 | + self.assertTrue(ents) |
| 263 | + for nr, ent in enumerate(ents.values()): |
| 264 | + with self.subTest(f"{nr}"): |
| 265 | + ConvertedFunctionalityTests.assert_has_ent(ent) |
| 266 | + |
| 267 | + def test_faster_spacy_inference_is_used(self): |
| 268 | + self._is_spacy_model() |
| 269 | + with self._faster_spacy_inference(): |
| 270 | + with unittest.mock.patch.object( |
| 271 | + self.model.pipe._tokenizer._nlp, |
| 272 | + '__call__') as dunder_call_mock: |
| 273 | + with unittest.mock.patch.object( |
| 274 | + self.model.pipe._tokenizer._nlp, |
| 275 | + 'make_doc') as make_doc_mock: |
| 276 | + self.model.get_entities( |
| 277 | + ConvertedFunctionalityTests.TEXT) |
| 278 | + dunder_call_mock.assert_not_called() |
| 279 | + make_doc_mock.assert_called() |
| 280 | + |
225 | 281 | def test_entities_in_correct_order(self): |
226 | 282 | # NOTE: the issue wouldn't show up with smaller amount of text |
227 | 283 | doc = self.model(ConvertedFunctionalityTests.TEXT * 3) |
|
0 commit comments