Answer generation using pubmed references to obtain relevant publications #51
backurs wants to merge 3 commits into
* dynamic retriever with native function calling
* fix example yaml file
* copilot comments
* copilot comments
* copilot comments
* refactoring and addressing comments
Pull request overview
Adds a PubMed/PMC answer-generation workflow that augments vector-retrieved articles with reference-based expansions, plus a RAGAS-based comparison script and updated configuration/docs to run and evaluate the generations.
Changes:
- Added answer_questions_ref.py to retrieve top-k articles, expand via references, and generate answers with an LLM.
- Added evaluate_answers_ragas.py to score/compare two answer JSON outputs using RAGAS-style LLM-graded metrics.
- Updated config.pubmed.yaml and added pubmed/scripts/readme.md with usage/evaluation instructions.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| pubmed/scripts/answer_questions_ref.py | New CLI pipeline for vector retrieval + reference expansion + answer generation. |
| pubmed/scripts/evaluate_answers_ragas.py | New CLI evaluator that compares two answer sets using LLM-based numeric metrics. |
| pubmed/scripts/config.pubmed.yaml | Adds answer-generation config and updates LLM/embedding settings used by the new scripts. |
| pubmed/scripts/readme.md | Documents how to run generation with/without references and compare outputs. |
| api_version: "2024-05-01-preview" | ||
| llm_model: "gpt-5.4-mini" | ||
| api_version: "2024-12-01-preview" | ||
| llm_model: "" |
llm.llm_model is set to an empty string here. Scripts that read this via .get("llm_model", <default>) will receive "" and fail when calling Azure OpenAI. Either provide a sensible default deployment name here, or keep it unset and add validation that errors clearly when it’s missing.
| llm_model: "" | |
| # Leave llm_model unset here so code that uses .get("llm_model", <default>) | |
| # can fall back to its default deployment name, or validation can clearly | |
| # report that the setting is missing. |
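The behaviour described in this comment can be reproduced with a plain dictionary. The snippet below is a minimal sketch, assuming a hypothetical `llm_cfg` dict that stands in for the parsed YAML; it only shows why an empty string defeats the `.get(..., default)` fallback.

```python
# Hypothetical stand-in for the parsed `llm` section of config.pubmed.yaml.
llm_cfg = {"llm_model": ""}

# dict.get only uses the default when the key is absent,
# so the empty string from the config is returned unchanged.
via_get = llm_cfg.get("llm_model", "gpt-5.4")

# `or` treats "" (and None) as missing and applies the fallback.
via_or = llm_cfg.get("llm_model") or "gpt-5.4"

# With the key removed entirely, .get falls back as intended.
via_unset = {}.get("llm_model", "gpt-5.4")

print(repr(via_get), repr(via_or), repr(via_unset))
```

Either leaving the key unset or switching the reader to the `or` form avoids passing an empty deployment name to Azure OpenAI.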
| <pre> | ||
| ===================================================================== | ||
| RAGAS Comparison: vector15 vs refs_5_4 | ||
| ====================================================================== | ||
| Metric vector15 refs_5_4 delta | ||
| ---------------------------------------------------------------------- | ||
| comprehensiveness 0.9278 0.9240 -0.0038 | ||
| evidence_use 0.8940 0.9060 + 0.0120 | ||
| accuracy 0.8908 0.8899 -0.0009 | ||
| clarity 0.9329 0.9349 + 0.0020 | ||
| relevance 0.9429 0.9414 -0.0015 | ||
| ====================================================================== | ||
| </pre> No newline at end of file |
The sample output is wrapped in an HTML <pre> block. In Markdown, prefer a fenced code block (triple backticks) so the formatting renders consistently across viewers and doesn’t mix HTML unnecessarily.
| embed_endpoint = embed_cfg.get("embed_endpoint", "") | ||
| embed_api_key = embed_cfg.get("embed_api_key", "") | ||
| embed_api_version = embed_cfg.get("api_version", "2024-12-01-preview") | ||
| embed_model = embed_cfg.get("embed_model", "text-embedding-3-small") | ||
| embed_dims = int(embed_cfg.get("embed_dimensions", 1536)) | ||
| embed_max_input_chars = int(embed_cfg.get("max_input_chars", 8000)) | ||
| pmc_ids_csv = answer_cfg.get("pmc_ids_csv", "") | ||
| max_context_chars = int(answer_cfg.get("max_context_chars", 230000)) | ||
| system_prompt = answer_cfg.get("system_prompt", "You are a biomedical research assistant.").strip() | ||
|
|
||
| with open(questions_path) as f: | ||
| raw = json.load(f) | ||
| questions = [item["question"] for item in raw] if raw and isinstance(raw[0], dict) else raw | ||
| log.info(f"Loaded {len(questions)} questions") | ||
|
|
||
| pmid_map = load_pmid_to_pmcid(pmc_ids_csv) if pmc_ids_csv else {} | ||
| log.info(f"Loaded {len(pmid_map)} PMID→PMCID mappings") | ||
|
|
||
| cred = AzureCliCredential() | ||
| container = CosmosClient(cosmos_uri, credential=cred) \ | ||
| .get_database_client(database_name).get_container_client(container_name) | ||
| embed_client = AzureOpenAI( | ||
| azure_endpoint=embed_endpoint, api_key=embed_api_key, api_version=embed_api_version) | ||
| llm_client = AzureOpenAI( |
The embedding client is always created with api_key=embed_api_key even though the repo config supports embedding.use_rbac_auth. This will break when RBAC is enabled (and embed_api_key is empty), and it’s inconsistent with how other code paths authenticate embeddings. Use azure_ad_token_provider when embedding.use_rbac_auth is true, otherwise require embed_api_key.
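One way to act on this comment is to choose the authentication arguments before constructing the client. The helper below is a sketch, not the repository's code: `embedding_client_kwargs` and its `token_provider` parameter are hypothetical names, and the `use_rbac_auth` key is assumed to sit in the same embedding config section as the other settings.

```python
from typing import Callable, Optional


def embedding_client_kwargs(
    embed_cfg: dict,
    token_provider: Optional[Callable[[], str]] = None,
) -> dict:
    """Build keyword arguments for AzureOpenAI(...) from the embedding config."""
    kwargs = {
        "azure_endpoint": (embed_cfg.get("embed_endpoint") or "").strip(),
        "api_version": embed_cfg.get("api_version") or "2024-12-01-preview",
    }
    if embed_cfg.get("use_rbac_auth"):
        # RBAC path: authenticate with an Entra ID token instead of a key.
        if token_provider is None:
            raise ValueError("use_rbac_auth is true but no token provider was supplied")
        kwargs["azure_ad_token_provider"] = token_provider
    else:
        # Key path: fail early rather than sending an empty api_key.
        api_key = embed_cfg.get("embed_api_key") or ""
        if not api_key:
            raise ValueError("embed_api_key is required when use_rbac_auth is false")
        kwargs["api_key"] = api_key
    return kwargs
```

In the script, the provider could come from `azure.identity.get_bearer_token_provider(cred, llm_token_scope)` and the result could be passed as `AzureOpenAI(**embedding_client_kwargs(embed_cfg, provider))`. Whether the embedding deployment uses the same token scope as the LLM is an assumption to verify.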
| SELECT TOP @top | ||
| c.id, c.pmcid, c.title, c.journal_title, c.abstract, c.pub_year, | ||
| c.doi, c.full_text, c.embedding, c.references, | ||
| VectorDistance(c.embedding, @embedding) AS similarity |
VectorDistance(...) is stored under the name similarity, but Cosmos returns a distance (lower is better). This is misleading in logs and in build_context where it’s printed as a similarity score. Rename to distance (or convert to an actual similarity) to avoid incorrect interpretation of retrieval quality.
| VectorDistance(c.embedding, @embedding) AS similarity | |
| VectorDistance(c.embedding, @embedding) AS distance |
| retrieved = [{"pmcid": d.get("pmcid"), "title": d.get("title"), | ||
| "journal": d.get("journal_title"), "year": d.get("pub_year"), | ||
| "similarity": d.get("similarity"), "source": "vector" if d in hits else "reference"} | ||
| for d in all_docs] |
Determining source via "vector" if d in hits else "reference" relies on full dict equality and can misclassify documents if dict contents differ (or if a ref doc happens to compare equal). Track source explicitly (e.g., add a source field when constructing hits / ref_docs_all, or compare by a stable key like pmcid).
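A sketch of the explicit-tagging option follows. The `tag_source` helper and the two sample documents are hypothetical; the point is only that the label is attached when each list is built, so the later summary never relies on dict equality.

```python
def tag_source(docs: list[dict], source: str) -> list[dict]:
    """Return shallow copies of docs carrying an explicit 'source' label."""
    return [{**d, "source": source} for d in docs]


# Hypothetical sample data standing in for the vector hits and reference docs.
hits = [{"pmcid": "PMC1", "title": "A"}]
ref_docs_all = [{"pmcid": "PMC2", "title": "B"}]

all_docs = tag_source(hits, "vector") + tag_source(ref_docs_all, "reference")

# The summary reads the label directly instead of testing `d in hits`.
retrieved = [
    {"pmcid": d.get("pmcid"), "title": d.get("title"), "source": d["source"]}
    for d in all_docs
]
```

Copying each dict keeps the original query results untouched, which matters if the same documents are reused elsewhere in the pipeline.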
| llm_endpoint = llm_cfg.get("llm_endpoint", "") | ||
| llm_model = llm_cfg.get("llm_model", "gpt-5.4") | ||
| llm_api_version = llm_cfg.get("api_version", "2024-12-01-preview") | ||
| llm_token_scope = llm_cfg.get("token_scope", "https://cognitiveservices.azure.com/.default") | ||
| llm_temperature = float(llm_cfg.get("temperature", 0.0)) | ||
| llm_max_tokens = int(llm_cfg.get("max_completion_tokens", 4096)) | ||
| llm_max_retries = int(llm_cfg.get("max_retries", 5)) | ||
| embed_endpoint = embed_cfg.get("embed_endpoint", "") | ||
| embed_api_key = embed_cfg.get("embed_api_key", "") | ||
| embed_api_version = embed_cfg.get("api_version", "2024-12-01-preview") | ||
| embed_model = embed_cfg.get("embed_model", "text-embedding-3-small") | ||
| embed_dims = int(embed_cfg.get("embed_dimensions", 1536)) | ||
| embed_max_input_chars = int(embed_cfg.get("max_input_chars", 8000)) | ||
| pmc_ids_csv = answer_cfg.get("pmc_ids_csv", "") | ||
| max_context_chars = int(answer_cfg.get("max_context_chars", 230000)) | ||
| system_prompt = answer_cfg.get("system_prompt", "You are a biomedical research assistant.").strip() | ||
|
|
llm_model and token_scope are read with .get(...), so an empty string in config will override the intended defaults and later cause Azure OpenAI calls/token acquisition to fail (e.g., model="" or empty scope). Consider using llm_cfg.get(...) or <default> and/or validating required settings (endpoint, deployment/model, token scope) early with a clear error.
| llm_endpoint = llm_cfg.get("llm_endpoint", "") | |
| llm_model = llm_cfg.get("llm_model", "gpt-5.4") | |
| llm_api_version = llm_cfg.get("api_version", "2024-12-01-preview") | |
| llm_token_scope = llm_cfg.get("token_scope", "https://cognitiveservices.azure.com/.default") | |
| llm_temperature = float(llm_cfg.get("temperature", 0.0)) | |
| llm_max_tokens = int(llm_cfg.get("max_completion_tokens", 4096)) | |
| llm_max_retries = int(llm_cfg.get("max_retries", 5)) | |
| embed_endpoint = embed_cfg.get("embed_endpoint", "") | |
| embed_api_key = embed_cfg.get("embed_api_key", "") | |
| embed_api_version = embed_cfg.get("api_version", "2024-12-01-preview") | |
| embed_model = embed_cfg.get("embed_model", "text-embedding-3-small") | |
| embed_dims = int(embed_cfg.get("embed_dimensions", 1536)) | |
| embed_max_input_chars = int(embed_cfg.get("max_input_chars", 8000)) | |
| pmc_ids_csv = answer_cfg.get("pmc_ids_csv", "") | |
| max_context_chars = int(answer_cfg.get("max_context_chars", 230000)) | |
| system_prompt = answer_cfg.get("system_prompt", "You are a biomedical research assistant.").strip() | |
| llm_endpoint = (llm_cfg.get("llm_endpoint") or "").strip() | |
| llm_model = llm_cfg.get("llm_model") or "gpt-5.4" | |
| llm_api_version = llm_cfg.get("api_version") or "2024-12-01-preview" | |
| llm_token_scope = llm_cfg.get("token_scope") or "https://cognitiveservices.azure.com/.default" | |
| llm_temperature = float(llm_cfg.get("temperature", 0.0)) | |
| llm_max_tokens = int(llm_cfg.get("max_completion_tokens", 4096)) | |
| llm_max_retries = int(llm_cfg.get("max_retries", 5)) | |
| embed_endpoint = (embed_cfg.get("embed_endpoint") or "").strip() | |
| embed_api_key = embed_cfg.get("embed_api_key", "") | |
| embed_api_version = embed_cfg.get("api_version") or "2024-12-01-preview" | |
| embed_model = embed_cfg.get("embed_model") or "text-embedding-3-small" | |
| embed_dims = int(embed_cfg.get("embed_dimensions", 1536)) | |
| embed_max_input_chars = int(embed_cfg.get("max_input_chars", 8000)) | |
| pmc_ids_csv = answer_cfg.get("pmc_ids_csv", "") | |
| max_context_chars = int(answer_cfg.get("max_context_chars", 230000)) | |
| system_prompt = (answer_cfg.get("system_prompt") or "You are a biomedical research assistant.").strip() | |
| if not llm_endpoint: | |
| raise ValueError("Missing required LLM configuration: llm_endpoint") | |
| if not llm_model: | |
| raise ValueError("Missing required LLM configuration: llm_model") | |
| if not llm_token_scope: | |
| raise ValueError("Missing required LLM configuration: token_scope") | |
| if not embed_endpoint: | |
| raise ValueError("Missing required embedding configuration: embed_endpoint") |
| def fetch_docs_by_pmcids(container, pmcids: list[str]) -> list[dict]: | ||
| if not pmcids: | ||
| return [] | ||
| placeholders = ", ".join(f"@p{i}" for i in range(len(pmcids))) | ||
| sql = f""" | ||
| SELECT c.id, c.pmcid, c.title, c.journal_title, c.abstract, c.pub_year, | ||
| c.doi, c.full_text, c.embedding | ||
| FROM c WHERE c.pmcid IN ({placeholders}) | ||
| """ | ||
| params = [{"name": f"@p{i}", "value": p} for i, p in enumerate(pmcids)] | ||
| return list(container.query_items(query=sql, parameters=params, enable_cross_partition_query=True)) |
fetch_docs_by_pmcids pulls full_text (and other large fields) for every referenced PMCID, even though the selection step only needs embeddings to compute similarity. This can dramatically increase RU consumption and latency. Consider fetching only pmcid + embedding for candidate refs, then fetching full metadata/full_text only for the selected ref_k docs.
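The two-phase fetch suggested here could reuse one query builder with different field lists. This is a sketch under assumptions: `build_pmcid_query`, `LIGHT_FIELDS`, and `FULL_FIELDS` are hypothetical names, and the field names are taken from the query shown in the diff above.

```python
def build_pmcid_query(pmcids: list[str], fields: list[str]) -> tuple[str, list[dict]]:
    """Build a parameterized IN (...) query that selects only the given fields."""
    placeholders = ", ".join(f"@p{i}" for i in range(len(pmcids)))
    select = ", ".join(f"c.{name}" for name in fields)
    sql = f"SELECT {select} FROM c WHERE c.pmcid IN ({placeholders})"
    params = [{"name": f"@p{i}", "value": p} for i, p in enumerate(pmcids)]
    return sql, params


# Phase 1: only what closest_select needs to rank candidate references.
LIGHT_FIELDS = ["pmcid", "embedding"]
# Phase 2: full metadata and text, fetched only for the selected ref_k docs.
FULL_FIELDS = ["id", "pmcid", "title", "journal_title", "abstract",
               "pub_year", "doi", "full_text"]

sql, params = build_pmcid_query(["PMC1", "PMC2"], LIGHT_FIELDS)
print(sql)
```

Each `(sql, params)` pair can be passed to `container.query_items(query=sql, parameters=params, enable_cross_partition_query=True)` as in the original function; callers should still return early on an empty `pmcids` list, since an empty `IN ()` clause is not meaningful.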
| # 2. For each hit, get reference PMCIDs, fetch them, pick ref_k closest to query | ||
| ref_docs_all = [] | ||
| if ref_k > 0: | ||
| for doc in hits: | ||
| refs = doc.get("references") or [] | ||
| ref_pmcids = [p for p in refs_to_pmcids(refs, pmid_map) if p not in seen_pmcids] | ||
| if not ref_pmcids: | ||
| continue | ||
| fetched = fetch_docs_by_pmcids(container, ref_pmcids) | ||
| if not fetched: | ||
| continue | ||
| selected = closest_select(fetched, query_vec, ref_k) |
For each of the top_k hits, this fetches all referenced PMCIDs from Cosmos before selecting ref_k closest. Since papers can have large reference lists, this can fan out into very large IN (...) queries and many RU-heavy reads per question. Consider capping the number of references considered per hit (or globally), deduplicating across hits before fetching, and/or adding a hard limit like max_ref_candidates in config.
| # 2. For each hit, get reference PMCIDs, fetch them, pick ref_k closest to query | |
| ref_docs_all = [] | |
| if ref_k > 0: | |
| for doc in hits: | |
| refs = doc.get("references") or [] | |
| ref_pmcids = [p for p in refs_to_pmcids(refs, pmid_map) if p not in seen_pmcids] | |
| if not ref_pmcids: | |
| continue | |
| fetched = fetch_docs_by_pmcids(container, ref_pmcids) | |
| if not fetched: | |
| continue | |
| selected = closest_select(fetched, query_vec, ref_k) | |
| # 2. For each hit, get reference PMCIDs, fetch a bounded set, pick ref_k closest to query | |
| ref_docs_all = [] | |
| fetched_ref_doc_cache = {} | |
| max_ref_candidates_per_hit = max(ref_k * 10, ref_k) | |
| if ref_k > 0: | |
| for doc in hits: | |
| refs = doc.get("references") or [] | |
| ref_pmcids = [] | |
| ref_pmcids_seen_for_hit = set() | |
| for pmcid in refs_to_pmcids(refs, pmid_map): | |
| if pmcid in seen_pmcids or pmcid in ref_pmcids_seen_for_hit: | |
| continue | |
| ref_pmcids_seen_for_hit.add(pmcid) | |
| ref_pmcids.append(pmcid) | |
| if len(ref_pmcids) >= max_ref_candidates_per_hit: | |
| break | |
| if not ref_pmcids: | |
| continue | |
| cached = [fetched_ref_doc_cache[pmcid] for pmcid in ref_pmcids if pmcid in fetched_ref_doc_cache] | |
| missing_pmcids = [pmcid for pmcid in ref_pmcids if pmcid not in fetched_ref_doc_cache] | |
| fetched = fetch_docs_by_pmcids(container, missing_pmcids) if missing_pmcids else [] | |
| for fetched_doc in fetched: | |
| fetched_pmcid = fetched_doc.get("pmcid") | |
| if fetched_pmcid: | |
| fetched_ref_doc_cache[fetched_pmcid] = fetched_doc | |
| candidates = cached + fetched | |
| if not candidates: | |
| continue | |
| selected = closest_select(candidates, query_vec, ref_k) |
| --top 10 --ref-k 2 | ||
| """ | ||
|
|
||
| import argparse, csv, json, logging, sys, time |
Unused import: sys is imported but never used. Please remove it to keep the script clean (and to satisfy linters if enabled).
| import argparse, csv, json, logging, sys, time | |
| import argparse, csv, json, logging, time |
| other_answers = load_answers(args.other) | ||
| log.info(f"Loaded {len(base_answers)} from {args.base}, {len(other_answers)} from {args.other}") | ||
|
|
||
| assert len(base_answers) == len(other_answers), "Answer files must have the same number of questions" |
Using assert for input validation is risky because assertions can be disabled with Python optimizations (-O). Replace this with an explicit runtime check that raises a ValueError (or exits non-zero) when the answer files have different lengths.
| assert len(base_answers) == len(other_answers), "Answer files must have the same number of questions" | |
| if len(base_answers) != len(other_answers): | |
| raise ValueError("Answer files must have the same number of questions") |
This PR implements answer generation using PubMed references to obtain more relevant publications. See readme.md for how to generate answers with and without references, and how to evaluate (compare) the generations.