Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 76 additions & 48 deletions core/common/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydash import compact, get, has, set_
from sentence_transformers import CrossEncoder

from core.common import ERRBIT_LOGGER
from core.common.constants import ES_REQUEST_TIMEOUT
from core.common.utils import is_url_encoded_string

Expand Down Expand Up @@ -336,42 +337,43 @@ def __get_response(self, exact_count=True, load_fields=False):
return self._dsl_search, None, total


class VectorEmbed:
def __init__(self, model_name=None):
self.model_name = model_name or settings.LM_MODEL_NAME

def embed(self, txt):
if settings.EMBEDDING_SERVICE_URL:
return self._get_embedding_from_service(txt)
return self._get_embedding_locally(txt)

def _get_embedding_from_service(self, txt):
try:
import requests as req

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (b): import requests as req here and again in _get_rerank_scores_from_service (~line 418). requests is already a top-level dependency — hoist a single import requests into the module import block and drop the per-call aliased imports.

response = req.post(
f'{settings.EMBEDDING_SERVICE_URL}/embeddings',
headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'},
json={'model': self.model_name, 'input': str(txt)},
timeout=10
)
response.raise_for_status()
return response.json()['data'][0]['embedding']
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return self._get_embedding_locally(txt)

def _get_embedding_locally(self, txt):
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(self.model_name)
return list(model.encode(str(txt)))
Comment on lines +364 to +367

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1 + #3 — cache the local model and return native floats. The fallback currently reloads a ~400MB SentenceTransformer on every embed() call (the old path used the preloaded settings.LM singleton), and returns list(np.ndarray) = np.float32 elements where the old call site relied on .tolist() to get JSON-safe native floats. Caching per model name makes a mid-indexing outage cheap instead of catastrophic, and .tolist() restores native floats for the ES query_vector/index payload:

Suggested change
def _get_embedding_locally(self, txt):
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(self.model_name)
return list(model.encode(str(txt)))
_LOCAL_MODELS = {}
def _get_embedding_locally(self, txt):
model = self._LOCAL_MODELS.get(self.model_name)
if model is None:
from sentence_transformers import SentenceTransformer
model = self._LOCAL_MODELS[self.model_name] = SentenceTransformer(self.model_name)
return model.encode(str(txt)).tolist()



class Reranker:
ENCODERS = [
# Best and Fastest overall lightweight medical reranker
# Size: ~110M
# Speed: similar to MiniLM CrossEncoder
# Training: includes clinical, medical, question-answering datasets
# Output: positive similarity scores (not raw logits!)
# 0.6B params
# https://huggingface.co/BAAI/bge-reranker-v2-m3
"BAAI/bge-reranker-v2-m3",

# Model: jinhybr/OA-MedBERT-cross-encoder or similar
# Size: ~110M
# Domain: PubMed abstracts, biomedical QA
# Type: binary classifier (logits)
# Not huggin face model -- ???
# "jinhybr/OA-MedBERT-cross-encoder",

# Model: microsoft/BioLinkBERT-base
# Type: CrossEncoder
# Size: ~120M
# Domain: UMLS, PubMed, MeSH, SNOMED (closest to OCL)
# Not huggin face model -- doesn't work with sentence_transformers
# "microsoft/BioLinkBERT-base",

# 22.7M params
# https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2
# doesn't work with logits, so not between 0-1
"cross-encoder/ms-marco-MiniLM-L-6-v2",
]
SCORE_KEY = 'search_rerank_score'
MISSING_SCORE = -1000000.0

def __init__(self, model_name=None):
self.model_name = model_name
self.encoder = self._get_encoder(self.model_name)
self.model_name = model_name or self.default_model
self.encoder = None

def rerank( # pylint: disable=too-many-arguments
self, hits, txt, name_key='name', source_attr=None, should_convert_source_to_dict=True,
Expand All @@ -393,18 +395,54 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc
return scores_full

docs = [get(self._get_source(hit, source_attr, should_convert_source_to_dict), name_key) for hit in hits]
valid = []
valid_docs = []
for i, d in enumerate(docs):
if isinstance(d, str) and d.strip():
valid.append((i, d.strip()))
if not valid:
valid_docs.append((i, d.strip()))
if not valid_docs:
return scores_full
scores = self.encoder.predict([(txt, d) for _, d in valid])
for (i, _), s in zip(valid, scores):

scores = self._get_rerank_scores(txt, valid_docs)
for (i, _), s in zip(valid_docs, scores):
scores_full[i] = float(s)

return scores_full

def _get_rerank_scores(self, txt, docs):
if settings.EMBEDDING_SERVICE_URL:
return self._get_rerank_scores_from_service(txt, docs)
return self._get_rerank_scores_locally(txt, docs)

def _get_rerank_scores_from_service(self, txt, docs):
try:
import requests as req
response = req.post(
f'{settings.EMBEDDING_SERVICE_URL}/rerank',
headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'},
json={
'model': self.model_name or self.default_model,
'query': txt,
'documents': [d for _, d in docs],
},
timeout=60
)
response.raise_for_status()
results = response.json()['results']
# results is a list of {index, relevance_score} sorted by index
return [r['relevance_score'] for r in sorted(results, key=lambda r: r['index'])]
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return self._get_rerank_scores_locally(txt, docs)

def _get_rerank_scores_locally(self, txt, docs):
try:
if not self.encoder:
self.encoder = self._get_encoder()
return self.encoder.predict([(txt, d) for _, d in docs])
except Exception as ex: # pylint: disable=broad-except
ERRBIT_LOGGER.log(ex)
return [self.MISSING_SCORE] * len(docs)

def _assign_score(self, hits, scores, score_key, order_results):
score_key = score_key or self.SCORE_KEY
key_to_set = score_key
Expand All @@ -420,18 +458,8 @@ def _assign_score(self, hits, scores, score_key, order_results):
def _order(hits, key_to_order):
return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True)

def _get_encoder(self, model_name):
if model_name and model_name != self.default_model:
return self._load_encoder(model_name)
return self._load_default_encoder()

@staticmethod
def _load_encoder(model_name):
return CrossEncoder(model_name, device="cpu", max_length=128)

@staticmethod
def _load_default_encoder():
return settings.ENCODER
def _get_encoder(self):
return CrossEncoder(self.model_name, device="cpu", max_length=128)

@staticmethod
def _get_source(data, source_attr, should_convert_source_to_dict):
Expand Down
174 changes: 174 additions & 0 deletions core/common/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,3 +1621,177 @@ def test_multi_token_input_expands_each_known_token(self, mock_load, mock_resolv

terms = LexicalVariantDictionary.get_variant_terms('childhood leukaemia colour')
self.assertEqual(set(terms), {'leukemia', 'color'})


class VectorEmbedTest(OCLTestCase):
def setUp(self):
self.embedder = None

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_embed_uses_service_when_url_configured(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={'data': [{'embedding': [0.1, 0.2, 0.3]}]})
)
from core.common.search import VectorEmbed
result = VectorEmbed().embed('malaria')
self.assertEqual(result, [0.1, 0.2, 0.3])
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
self.assertIn('/embeddings', call_kwargs[0][0])
self.assertEqual(call_kwargs[1]['headers']['Authorization'], 'Bearer test-key')
self.assertEqual(call_kwargs[1]['json']['input'], 'malaria')

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='')
@patch('requests.post')
def test_embed_falls_back_to_local_on_service_error(self, mock_post):
mock_post.side_effect = Exception('connection refused')
from core.common.search import VectorEmbed
embedder = VectorEmbed()
with patch.object(embedder, '_get_embedding_locally', return_value=[0.4, 0.5]) as mock_local:
result = embedder.embed('diabetes')
mock_local.assert_called_once_with('diabetes')
self.assertEqual(result, [0.4, 0.5])

@override_settings(EMBEDDING_SERVICE_URL='')
def test_embed_uses_local_when_no_service_url(self):
from core.common.search import VectorEmbed
embedder = VectorEmbed()
with patch.object(embedder, '_get_embedding_locally', return_value=[0.1, 0.2]) as mock_local:
result = embedder.embed('hypertension')
mock_local.assert_called_once_with('hypertension')
self.assertEqual(result, [0.1, 0.2])

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='')
@patch('requests.post')
def test_embed_uses_custom_model_name(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={'data': [{'embedding': [0.9]}]})
)
from core.common.search import VectorEmbed
VectorEmbed(model_name='custom/model').embed('test')
self.assertEqual(mock_post.call_args[1]['json']['model'], 'custom/model')


class RerankerTest(OCLTestCase):
def _make_hit(self, name, use_search_meta=False):
if use_search_meta:
return {'name': name, 'search_meta': {}}
return {'name': name}

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_uses_service_when_url_configured(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={
'results': [
{'index': 0, 'relevance_score': 0.9},
{'index': 1, 'relevance_score': 0.3},
]
})
)
from core.common.search import Reranker
hits = [self._make_hit('malaria fever'), self._make_hit('diabetes')]
result = Reranker().rerank(hits, 'malaria', order_results=False)
self.assertEqual(result[0]['search_rerank_score'], 0.9)
self.assertEqual(result[1]['search_rerank_score'], 0.3)
mock_post.assert_called_once()
self.assertIn('/rerank', mock_post.call_args[0][0])

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_falls_back_to_local_on_service_error(self, mock_post):
mock_post.side_effect = Exception('timeout')
from core.common.search import Reranker
reranker = Reranker()
hits = [self._make_hit('malaria')]
with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.7]) as mock_local:
reranker.rerank(hits, 'malaria', order_results=False)
mock_local.assert_called_once()

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_returns_missing_score_when_both_fail(self, mock_post):
mock_post.side_effect = Exception('timeout')
from core.common.search import Reranker
reranker = Reranker()
with patch.object(reranker, '_get_rerank_scores_locally', return_value=[Reranker.MISSING_SCORE]):
hits = [self._make_hit('malaria')]
result = reranker.rerank(hits, 'malaria', order_results=False)
self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)

@override_settings(EMBEDDING_SERVICE_URL='')
def test_rerank_uses_local_when_no_service_url(self):
from core.common.search import Reranker
reranker = Reranker()
hits = [self._make_hit('malaria')]
with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.8]) as mock_local:
reranker.rerank(hits, 'malaria', order_results=False)
mock_local.assert_called_once()

def test_rerank_returns_empty_on_no_hits(self):
from core.common.search import Reranker
result = Reranker().rerank([], 'malaria')
self.assertEqual(result, [])

def test_rerank_returns_missing_score_on_blank_query(self):
from core.common.search import Reranker
hits = [self._make_hit('malaria')]
result = Reranker().rerank(hits, ' ', order_results=False)
self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)

def test_rerank_skips_hits_with_missing_name(self):
from core.common.search import Reranker
reranker = Reranker()
hits = [self._make_hit(''), self._make_hit('malaria')]
with patch.object(reranker, '_get_rerank_scores_locally', return_value=[0.9]) as mock_local:
result = reranker.rerank(hits, 'malaria', order_results=False)
# only one valid doc passed to scorer
self.assertEqual(len(mock_local.call_args[0][1]), 1)
self.assertEqual(result[0]['search_rerank_score'], Reranker.MISSING_SCORE)
self.assertEqual(result[1]['search_rerank_score'], 0.9)

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_orders_results_by_score_descending(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={
'results': [
{'index': 0, 'relevance_score': 0.2},
{'index': 1, 'relevance_score': 0.8},
]
})
)
from core.common.search import Reranker
hits = [self._make_hit('diabetes'), self._make_hit('malaria fever')]
result = Reranker().rerank(hits, 'malaria')
self.assertEqual(result[0]['name'], 'malaria fever')
self.assertEqual(result[1]['name'], 'diabetes')

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_assigns_score_to_search_meta_when_present(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={'results': [{'index': 0, 'relevance_score': 0.5}]})
)
from core.common.search import Reranker
hits = [self._make_hit('malaria', use_search_meta=True)]
result = Reranker().rerank(hits, 'malaria', order_results=False)
self.assertEqual(result[0]['search_meta']['search_rerank_score'], 0.5)
self.assertEqual(result[0]['search_meta']['search_normalized_score'], 50.0)

@override_settings(EMBEDDING_SERVICE_URL='http://embed-service:8008', INFINITY_API_KEY='test-key')
@patch('requests.post')
def test_rerank_uses_custom_model(self, mock_post):
mock_post.return_value = Mock(
status_code=200,
json=Mock(return_value={'results': [{'index': 0, 'relevance_score': 0.6}]})
)
from core.common.search import Reranker
Reranker(model_name='custom/reranker').rerank([self._make_hit('malaria')], 'malaria')
self.assertEqual(mock_post.call_args[1]['json']['model'], 'custom/reranker')
12 changes: 0 additions & 12 deletions core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,15 +918,3 @@ def format_url_for_search(url):

def clean_term(term):
return term.lower().replace(' ', '').replace('-', '').replace('_', '')


def get_embeddings(txt):
from core.toggles.models import Toggle
if not Toggle.get('SEMANTIC_SEARCH_TOGGLE') or settings.ENV == 'ci':
return None

model = settings.LM
if not model:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(settings.LM_MODEL_NAME)
return model.encode(str(txt))
8 changes: 5 additions & 3 deletions core/concepts/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from django_elasticsearch_dsl.registries import registry
from pydash import compact, get

from core.common.utils import jsonify_safe, flatten_dict, get_embeddings, drop_version
from core.common.search import VectorEmbed
from core.common.utils import jsonify_safe, flatten_dict, drop_version
from core.concepts.models import Concept


Expand Down Expand Up @@ -244,14 +245,15 @@ def prepare(self, instance):
data['_synonyms'] = data['synonyms']

if instance.parent.has_semantic_match_algorithm:
_embedder = VectorEmbed()
data['_embeddings'] = {
'vector': get_embeddings(name),
'vector': _embedder.embed(name),
'type': get(preferred_locale, 'type'),
'locale': get(preferred_locale, 'locale')
}
data['_synonyms_embeddings'] = [
{
'vector': get_embeddings(s.name),
'vector': _embedder.embed(s.name),
'type': get(s, 'type'),
'locale': get(s, 'locale')
} for s in synonyms
Expand Down
Loading