Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ df = pd.read_sql(query, engine)

#### Django REST
- The email and password are set in `server/api/management/commands/createsu.py`
- Backend tests can be run using `pytest` by running the below command inside the running backend container:

```
docker compose exec backend pytest api/ -v
```

## API Documentation

Expand Down
27 changes: 27 additions & 0 deletions server/api/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,30 @@
class ApiConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'api'

    def ready(self):
        """Preload the sentence-transformer model once the web server starts.

        ready() runs in every Django process (migrate, test, shell, runserver,
        management commands), so two guards narrow the preload down to the one
        process that will actually serve HTTP requests.
        """
        import os
        import sys

        # Guard 1: only the `runserver` command serves requests.
        # Dev (docker-compose.yml) runs `manage.py runserver 0.0.0.0:8000`;
        # prod (Dockerfile.prod CMD) runs `manage.py runserver 0.0.0.0:8000 --noreload`.
        # entrypoint.prod.sh also runs migrate, createsu, and populatedb before
        # exec'ing runserver — this check correctly skips those commands too.
        command = sys.argv[1] if len(sys.argv) > 1 else None
        if command != 'runserver':
            return

        # Guard 2: dev's autoreloader spawns a parent file-watcher and a child
        # server process; only the child (RUN_MAIN=true) serves requests, so
        # skip the parent to avoid loading the model twice per reload. Prod
        # uses --noreload, so RUN_MAIN is never set; the '--noreload' check
        # covers that case.
        is_serving_process = (
            os.environ.get('RUN_MAIN') == 'true' or '--noreload' in sys.argv
        )
        if not is_serving_process:
            return

        # Note: paraphrase-MiniLM-L6-v2 (~80MB) is downloaded from HuggingFace
        # on first use and cached to ~/.cache/torch/sentence_transformers/
        # inside the container. That cache is ephemeral — every container
        # rebuild re-downloads the model unless a volume is mounted there.
        from .services.sentencetTransformer_model import TransformerModel
        TransformerModel.get_instance()
Comment on lines +8 to +33
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ApiConfig.ready() will only run if this AppConfig is actually used by Django. Right now INSTALLED_APPS appears to include just "api" (not "api.apps.ApiConfig"), and api/__init__.py doesn’t set a default config, so this preload hook may never execute. Consider updating INSTALLED_APPS to reference api.apps.ApiConfig (or otherwise ensuring this config is selected) so the model is preloaded as intended.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model preloads as intended because Django ≥ 3.2 auto discovers AppConfig subclasses in apps.py

Comment on lines +8 to +33
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling TransformerModel.get_instance() unconditionally in ready() will run for every Django startup context (tests, migrations, management commands, autoreload) and can trigger a large model download/init even when no web traffic will be served. Consider gating this preload behind an explicit env flag (or limiting it to the web server entrypoint) to avoid slowing/fragilizing CI and one-off management commands.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a guard to only preload the model when we're actually going to serve requests

162 changes: 113 additions & 49 deletions server/api/services/embedding_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from statistics import median

# Django filter() only does AND logic; Q objects are needed for OR conditions
from django.db.models import Q
from pgvector.django import L2Distance

Expand All @@ -11,18 +12,17 @@

logger = logging.getLogger(__name__)

def get_closest_embeddings(
user, message_data, document_name=None, guid=None, num_results=10
):

def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10):
    """
    Build an unevaluated QuerySet for the closest embeddings.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    embedding_vector : array-like
        Pre-computed embedding vector to compare against
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    QuerySet
        Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results
    """
    # Django QuerySets are lazily evaluated: nothing here touches the database.
    if user.is_authenticated:
        # Authenticated users see their own files + files uploaded by superusers.
        queryset = Embeddings.objects.filter(
            Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
        )
    else:
        # Unauthenticated users only see superuser-uploaded files.
        queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)

    queryset = (
        queryset
        .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector))
        .order_by("distance")
    )

    # Filtering to a document GUID takes precedence over a document name.
    if guid:
        queryset = queryset.filter(upload_file__guid=guid)
    elif document_name:
        queryset = queryset.filter(name=document_name)

    # Slicing is equivalent to SQL's LIMIT clause.
    return queryset[:num_results]
Comment on lines 16 to +61
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build_query() introduces/relocates important filtering + precedence logic (authenticated vs unauthenticated visibility; guid-over-document_name; LIMIT slicing), but the new tests only cover evaluate_query and log_usage. Add unit/integration tests covering build_query behavior (e.g., guid precedence and the authenticated/unauthenticated queryset filters) to prevent regressions in access control and filtering.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Building on Copilot's comment, the specifics of the QuerySet object's structure aren't publicly documented. To inspect the QuerySets, we should actually execute them.

There are a couple of ways we handle DB access for these tests. We could use [pytest-django's `@pytest.mark.django_db`](https://pytest-django.readthedocs.io/en/latest/database.html), which wraps the test in a transaction that rolls back automatically afterwards. Django also has a built-in `django.test.TestCase`, which does a similar thing.

Copy link
Collaborator Author

@sahilds1 sahilds1 Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing the docs references -- I added tests for build_query and didn't have to access the database because I was able to inspect which methods and arguments were called on the model ("Embeddings")



def evaluate_query(queryset):
    """
    Evaluate a QuerySet and return a list of result dicts.

    Parameters
    ----------
    queryset : iterable
        Iterable of Embeddings objects (or any objects with the expected attributes)

    Returns
    -------
    list[dict]
        List of dicts with keys: name, text, page_number, chunk_number, distance, file_id
    """
    # Iterating evaluates the QuerySet and hits the database.
    # TODO: Research improving the query evaluation performance
    return [
        {
            "name": obj.name,
            "text": obj.text,
            "page_number": obj.page_number,
            "chunk_number": obj.chunk_number,
            "distance": obj.distance,
            # upload_file may be null; surface None rather than raising
            "file_id": obj.upload_file.guid if obj.upload_file else None,
        }
        for obj in queryset
    ]

db_query_time = time.time() - db_query_start

def log_usage(
results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time
):
"""
Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.

Parameters
----------
results : list[dict]
The search results, each containing a "distance" key
message_data : str
The original search query text
user : User
The user who performed the search
guid : str or None
Document GUID filter used in the search
document_name : str or None
Document name filter used in the search
num_results : int
Number of results requested
encoding_time : float
Time in seconds to encode the query
db_query_time : float
Time in seconds for the database query
"""
try:
# Handle user having no uploaded docs or doc filtering returning no matches
if results:
distances = [r["distance"] for r in results]
SemanticSearchUsage.objects.create(
Expand All @@ -113,11 +129,10 @@ def get_closest_embeddings(
num_results_returned=len(results),
max_distance=max(distances),
median_distance=median(distances),
min_distance=min(distances)
min_distance=min(distances),
)
else:
logger.warning("Semantic search returned no results")

SemanticSearchUsage.objects.create(
query_text=message_data,
user=user if (user and user.is_authenticated) else None,
Expand All @@ -129,9 +144,58 @@ def get_closest_embeddings(
num_results_returned=0,
max_distance=None,
median_distance=None,
min_distance=None
min_distance=None,
)
except Exception as e:
logger.error(f"Failed to create semantic search usage database record: {e}")


def get_closest_embeddings(
    user, message_data, document_name=None, guid=None, num_results=10
):
    """
    Find the closest embeddings to a given message for a specific user.

    Parameters
    ----------
    user : User
        The user whose uploaded documents will be searched
    message_data : str
        The input message to find similar embeddings for
    document_name : str, optional
        Filter results to a specific document name
    guid : str, optional
        Filter results to a specific document GUID (takes precedence over document_name)
    num_results : int, default 10
        Maximum number of results to return

    Returns
    -------
    list[dict]
        List of dictionaries containing embedding results with keys:
        - name: document name
        - text: embedded text content
        - page_number: page number in source document
        - chunk_number: chunk number within the document
        - distance: L2 distance from query embedding
        - file_id: GUID of the source file

    Notes
    -----
    Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted.
    """
    # Encode the query text into a vector, timing the step for the usage record.
    encode_started = time.time()
    transformer = TransformerModel.get_instance().model
    query_embedding = transformer.encode(message_data)
    encode_seconds = time.time() - encode_started

    # build_query returns a lazy QuerySet; evaluate_query is what hits the DB.
    query_started = time.time()
    embeddings_qs = build_query(user, query_embedding, document_name, guid, num_results)
    hits = evaluate_query(embeddings_qs)
    query_seconds = time.time() - query_started

    # Best-effort usage logging; log_usage handles its own exceptions.
    log_usage(
        hits, message_data, user, guid, document_name, num_results, encode_seconds, query_seconds
    )

    return hits
Loading
Loading