-
Notifications
You must be signed in to change notification settings - Fork 40
OpenConceptLab/ocl_online#116 | Using Infinity for embedding and reranking with fallback to inline load of model in api/indexing services #878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |||||||||||||||||||||||||
| from pydash import compact, get, has, set_ | ||||||||||||||||||||||||||
| from sentence_transformers import CrossEncoder | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from core.common import ERRBIT_LOGGER | ||||||||||||||||||||||||||
| from core.common.constants import ES_REQUEST_TIMEOUT | ||||||||||||||||||||||||||
| from core.common.utils import is_url_encoded_string | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -336,42 +337,43 @@ def __get_response(self, exact_count=True, load_fields=False): | |||||||||||||||||||||||||
| return self._dsl_search, None, total | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class VectorEmbed: | ||||||||||||||||||||||||||
| def __init__(self, model_name=None): | ||||||||||||||||||||||||||
| self.model_name = model_name or settings.LM_MODEL_NAME | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def embed(self, txt): | ||||||||||||||||||||||||||
| if settings.EMBEDDING_SERVICE_URL: | ||||||||||||||||||||||||||
| return self._get_embedding_from_service(txt) | ||||||||||||||||||||||||||
| return self._get_embedding_locally(txt) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_embedding_from_service(self, txt): | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| import requests as req | ||||||||||||||||||||||||||
| response = req.post( | ||||||||||||||||||||||||||
| f'{settings.EMBEDDING_SERVICE_URL}/embeddings', | ||||||||||||||||||||||||||
| headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'}, | ||||||||||||||||||||||||||
| json={'model': self.model_name, 'input': str(txt)}, | ||||||||||||||||||||||||||
| timeout=10 | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| response.raise_for_status() | ||||||||||||||||||||||||||
| return response.json()['data'][0]['embedding'] | ||||||||||||||||||||||||||
| except Exception as ex: # pylint: disable=broad-except | ||||||||||||||||||||||||||
| ERRBIT_LOGGER.log(ex) | ||||||||||||||||||||||||||
| return self._get_embedding_locally(txt) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_embedding_locally(self, txt): | ||||||||||||||||||||||||||
| from sentence_transformers import SentenceTransformer | ||||||||||||||||||||||||||
| model = SentenceTransformer(self.model_name) | ||||||||||||||||||||||||||
| return list(model.encode(str(txt))) | ||||||||||||||||||||||||||
|
Comment on lines
+364
to
+367
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #1 + #3 — cache the local model and return native floats. The fallback currently reloads a ~400MB
Suggested change
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| class Reranker: | ||||||||||||||||||||||||||
| ENCODERS = [ | ||||||||||||||||||||||||||
| # Best and Fastest overall lightweight medical reranker | ||||||||||||||||||||||||||
| # Size: ~110M | ||||||||||||||||||||||||||
| # Speed: similar to MiniLM CrossEncoder | ||||||||||||||||||||||||||
| # Training: includes clinical, medical, question-answering datasets | ||||||||||||||||||||||||||
| # Output: positive similarity scores (not raw logits!) | ||||||||||||||||||||||||||
| # 0.6B params | ||||||||||||||||||||||||||
| # https://huggingface.co/BAAI/bge-reranker-v2-m3 | ||||||||||||||||||||||||||
| "BAAI/bge-reranker-v2-m3", | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Model: jinhybr/OA-MedBERT-cross-encoder or similar | ||||||||||||||||||||||||||
| # Size: ~110M | ||||||||||||||||||||||||||
| # Domain: PubMed abstracts, biomedical QA | ||||||||||||||||||||||||||
| # Type: binary classifier (logits) | ||||||||||||||||||||||||||
| # Not huggin face model -- ??? | ||||||||||||||||||||||||||
| # "jinhybr/OA-MedBERT-cross-encoder", | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Model: microsoft/BioLinkBERT-base | ||||||||||||||||||||||||||
| # Type: CrossEncoder | ||||||||||||||||||||||||||
| # Size: ~120M | ||||||||||||||||||||||||||
| # Domain: UMLS, PubMed, MeSH, SNOMED (closest to OCL) | ||||||||||||||||||||||||||
| # Not huggin face model -- doesn't work with sentence_transformers | ||||||||||||||||||||||||||
| # "microsoft/BioLinkBERT-base", | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 22.7M params | ||||||||||||||||||||||||||
| # https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2 | ||||||||||||||||||||||||||
| # doesn't work with logits, so not between 0-1 | ||||||||||||||||||||||||||
| "cross-encoder/ms-marco-MiniLM-L-6-v2", | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| SCORE_KEY = 'search_rerank_score' | ||||||||||||||||||||||||||
| MISSING_SCORE = -1000000.0 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __init__(self, model_name=None): | ||||||||||||||||||||||||||
| self.model_name = model_name | ||||||||||||||||||||||||||
| self.encoder = self._get_encoder(self.model_name) | ||||||||||||||||||||||||||
| self.model_name = model_name or self.default_model | ||||||||||||||||||||||||||
| self.encoder = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def rerank( # pylint: disable=too-many-arguments | ||||||||||||||||||||||||||
| self, hits, txt, name_key='name', source_attr=None, should_convert_source_to_dict=True, | ||||||||||||||||||||||||||
|
|
@@ -393,18 +395,54 @@ def _predict_scores(self, hits, txt, name_key, source_attr, should_convert_sourc | |||||||||||||||||||||||||
| return scores_full | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| docs = [get(self._get_source(hit, source_attr, should_convert_source_to_dict), name_key) for hit in hits] | ||||||||||||||||||||||||||
| valid = [] | ||||||||||||||||||||||||||
| valid_docs = [] | ||||||||||||||||||||||||||
| for i, d in enumerate(docs): | ||||||||||||||||||||||||||
| if isinstance(d, str) and d.strip(): | ||||||||||||||||||||||||||
| valid.append((i, d.strip())) | ||||||||||||||||||||||||||
| if not valid: | ||||||||||||||||||||||||||
| valid_docs.append((i, d.strip())) | ||||||||||||||||||||||||||
| if not valid_docs: | ||||||||||||||||||||||||||
| return scores_full | ||||||||||||||||||||||||||
| scores = self.encoder.predict([(txt, d) for _, d in valid]) | ||||||||||||||||||||||||||
| for (i, _), s in zip(valid, scores): | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| scores = self._get_rerank_scores(txt, valid_docs) | ||||||||||||||||||||||||||
| for (i, _), s in zip(valid_docs, scores): | ||||||||||||||||||||||||||
| scores_full[i] = float(s) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return scores_full | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_rerank_scores(self, txt, docs): | ||||||||||||||||||||||||||
| if settings.EMBEDDING_SERVICE_URL: | ||||||||||||||||||||||||||
| return self._get_rerank_scores_from_service(txt, docs) | ||||||||||||||||||||||||||
| return self._get_rerank_scores_locally(txt, docs) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_rerank_scores_from_service(self, txt, docs): | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| import requests as req | ||||||||||||||||||||||||||
| response = req.post( | ||||||||||||||||||||||||||
| f'{settings.EMBEDDING_SERVICE_URL}/rerank', | ||||||||||||||||||||||||||
| headers={'Authorization': f'Bearer {settings.INFINITY_API_KEY}'}, | ||||||||||||||||||||||||||
| json={ | ||||||||||||||||||||||||||
| 'model': self.model_name or self.default_model, | ||||||||||||||||||||||||||
| 'query': txt, | ||||||||||||||||||||||||||
| 'documents': [d for _, d in docs], | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| timeout=60 | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| response.raise_for_status() | ||||||||||||||||||||||||||
| results = response.json()['results'] | ||||||||||||||||||||||||||
| # results is a list of {index, relevance_score} sorted by index | ||||||||||||||||||||||||||
| return [r['relevance_score'] for r in sorted(results, key=lambda r: r['index'])] | ||||||||||||||||||||||||||
| except Exception as ex: # pylint: disable=broad-except | ||||||||||||||||||||||||||
| ERRBIT_LOGGER.log(ex) | ||||||||||||||||||||||||||
| return self._get_rerank_scores_locally(txt, docs) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_rerank_scores_locally(self, txt, docs): | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| if not self.encoder: | ||||||||||||||||||||||||||
| self.encoder = self._get_encoder() | ||||||||||||||||||||||||||
| return self.encoder.predict([(txt, d) for _, d in docs]) | ||||||||||||||||||||||||||
| except Exception as ex: # pylint: disable=broad-except | ||||||||||||||||||||||||||
| ERRBIT_LOGGER.log(ex) | ||||||||||||||||||||||||||
| return [self.MISSING_SCORE] * len(docs) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _assign_score(self, hits, scores, score_key, order_results): | ||||||||||||||||||||||||||
| score_key = score_key or self.SCORE_KEY | ||||||||||||||||||||||||||
| key_to_set = score_key | ||||||||||||||||||||||||||
|
|
@@ -420,18 +458,8 @@ def _assign_score(self, hits, scores, score_key, order_results): | |||||||||||||||||||||||||
| def _order(hits, key_to_order): | ||||||||||||||||||||||||||
| return sorted(hits, key=lambda hit: get(hit, key_to_order), reverse=True) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_encoder(self, model_name): | ||||||||||||||||||||||||||
| if model_name and model_name != self.default_model: | ||||||||||||||||||||||||||
| return self._load_encoder(model_name) | ||||||||||||||||||||||||||
| return self._load_default_encoder() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||
| def _load_encoder(model_name): | ||||||||||||||||||||||||||
| return CrossEncoder(model_name, device="cpu", max_length=128) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||
| def _load_default_encoder(): | ||||||||||||||||||||||||||
| return settings.ENCODER | ||||||||||||||||||||||||||
| def _get_encoder(self): | ||||||||||||||||||||||||||
| return CrossEncoder(self.model_name, device="cpu", max_length=128) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||
| def _get_source(data, source_attr, should_convert_source_to_dict): | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit (b):
import requests as reqhere and again in_get_rerank_scores_from_service(~line 418).requestsis already a top-level dependency — hoist a singleimport requestsinto the module import block and drop the per-call aliased imports.