From 72f5795ef200883a72108b080ba9539386781dd4 Mon Sep 17 00:00:00 2001 From: cka-y Date: Thu, 14 May 2026 14:24:04 -0400 Subject: [PATCH 1/5] initial commit --- .github/workflows/api-deployer.yml | 8 +- .gitignore | 5 +- infra/main.tf | 10 + infra/mcp/main.tf | 142 ++++++++++++ infra/mcp/vars.tf | 40 ++++ infra/vars.tf | 5 + infra/vars.tfvars.rename_me | 1 + mcp/.env.rename_me | 4 + mcp/.gitignore | 1 + mcp/Dockerfile | 20 ++ mcp/README.md | 207 +++++++++++++++++ mcp/mcp_config.json | 5 + mcp/pyproject.toml | 3 + mcp/requirements.txt | 7 + mcp/run.sh | 3 + mcp/src/__init__.py | 0 mcp/src/client.py | 12 + mcp/src/gtfs_cache.py | 89 +++++++ mcp/src/main.py | 22 ++ mcp/src/rules_cache.py | 126 ++++++++++ mcp/src/tools/__init__.py | 0 mcp/src/tools/get_validation_results.py | 173 ++++++++++++++ mcp/src/tools/query_gtfs.py | 258 +++++++++++++++++++++ mcp/src/tools/search_feeds.py | 135 +++++++++++ mcp/tests/__init__.py | 0 mcp/tests/test_config.json | 4 + mcp/tests/test_get_validation_results.py | 283 +++++++++++++++++++++++ mcp/tests/test_query_gtfs.py | 236 +++++++++++++++++++ mcp/tests/test_search_feeds.py | 246 ++++++++++++++++++++ scripts/docker-build-push.sh | 23 +- scripts/function-python-setup.sh | 45 +++- 31 files changed, 2109 insertions(+), 4 deletions(-) create mode 100644 infra/mcp/main.tf create mode 100644 infra/mcp/vars.tf create mode 100644 mcp/.env.rename_me create mode 100644 mcp/.gitignore create mode 100644 mcp/Dockerfile create mode 100644 mcp/README.md create mode 100644 mcp/mcp_config.json create mode 100644 mcp/pyproject.toml create mode 100644 mcp/requirements.txt create mode 100755 mcp/run.sh create mode 100644 mcp/src/__init__.py create mode 100644 mcp/src/client.py create mode 100644 mcp/src/gtfs_cache.py create mode 100644 mcp/src/main.py create mode 100644 mcp/src/rules_cache.py create mode 100644 mcp/src/tools/__init__.py create mode 100644 mcp/src/tools/get_validation_results.py create mode 100644 mcp/src/tools/query_gtfs.py create mode 100644 mcp/src/tools/search_feeds.py create mode 100644 mcp/tests/__init__.py create mode 100644 mcp/tests/test_config.json create mode 100644 mcp/tests/test_get_validation_results.py create mode 100644 mcp/tests/test_query_gtfs.py create mode 100644 mcp/tests/test_search_feeds.py diff --git a/.github/workflows/api-deployer.yml b/.github/workflows/api-deployer.yml index f5f56a361..87410ae30 100644 --- a/.github/workflows/api-deployer.yml +++ b/.github/workflows/api-deployer.yml @@ -237,6 +237,11 @@ jobs: DOCKER_IMAGE_VERSION=$EXTRACTED_VERSION.$FEED_API_IMAGE_VERSION scripts/docker-build-push.sh -project_id $PROJECT_ID -repo_name feeds-$ENVIRONMENT -service feed-api -region $REGION -version $DOCKER_IMAGE_VERSION + - name: Build & Publish MCP Server Docker Image + run: | + MCP_IMAGE_VERSION=$EXTRACTED_VERSION.$FEED_API_IMAGE_VERSION + scripts/docker-build-push.sh -project_id $PROJECT_ID -repo_name feeds-$ENVIRONMENT -service mcp-server -region $REGION -version $MCP_IMAGE_VERSION -dockerfile mcp/Dockerfile -context . + terraform-deploy: runs-on: ubuntu-latest permissions: write-all @@ -293,6 +298,7 @@ jobs: echo "ENVIRONMENT=${{ inputs.ENVIRONMENT }}" >> $GITHUB_ENV echo "DEPLOYER_SERVICE_ACCOUNT=${{ inputs.DEPLOYER_SERVICE_ACCOUNT }}" >> $GITHUB_ENV echo "FEED_API_IMAGE_VERSION=$EXTRACTED_VERSION.${{ inputs.FEED_API_IMAGE_VERSION }}" >> $GITHUB_ENV + echo "MCP_IMAGE_VERSION=$EXTRACTED_VERSION.${{ inputs.FEED_API_IMAGE_VERSION }}" >> $GITHUB_ENV echo "OAUTH2_CLIENT_ID=${{ secrets.OAUTH2_CLIENT_ID }}" >> $GITHUB_ENV echo "OAUTH2_CLIENT_SECRET=${{ secrets.OAUTH2_CLIENT_SECRET }}" >> $GITHUB_ENV echo "GLOBAL_RATE_LIMIT_REQ_PER_MINUTE=${{ inputs.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }}" >> $GITHUB_ENV @@ -321,7 +327,7 @@ jobs: - name: Populate Variables run: | scripts/replace-variables.sh -in_file infra/backend.conf.rename_me -out_file infra/backend.conf -variables BUCKET_NAME,OBJECT_PREFIX - scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY,OPERATIONS_OAUTH2_CLIENT_ID,TDG_API_TOKEN,WEB_APP_REVALIDATE_URL,WEB_APP_REVALIDATE_SECRET + scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,MCP_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY,OPERATIONS_OAUTH2_CLIENT_ID,TDG_API_TOKEN,WEB_APP_REVALIDATE_URL,WEB_APP_REVALIDATE_SECRET - uses: hashicorp/setup-terraform@v3 with: diff --git a/.gitignore b/.gitignore index bd8f43cf6..8c2d96b6e 100644 --- a/.gitignore +++ b/.gitignore @@ -86,4 +86,7 @@ functions-python/**/*.csv *.code-workspace # Ignore OpenApi local backup files -*.yaml.bak \ No newline at end of file +*.yaml.bak + +# Claude folder +.claude \ No newline at end of file diff --git a/infra/main.tf b/infra/main.tf index ac33be87b..5fd6b98bd 100644 --- a/infra/main.tf +++ b/infra/main.tf @@ -100,6 +100,16 @@ module "feed-api" { } +module "mcp" { + source = "./mcp" + + project_id = var.project_id + gcp_region = var.gcp_region + environment = var.environment + docker_repository_name = "${var.artifact_repo_name}-${var.environment}" + mcp_image_version = var.mcp_image_version +} + module "functions-python" { source = "./functions-python" project_id = var.project_id diff --git a/infra/mcp/main.tf b/infra/mcp/main.tf new file mode 100644 index 000000000..4acfdf1e8 --- /dev/null +++ b/infra/mcp/main.tf @@ -0,0 +1,142 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script deploys the MCP Server cloud run service. +# The cloud run service is created with name: mcp-server-${var.environment} +# Module output: +# mcp_server_uri: URI of the MCP Server Cloud Run service + +data "google_project" "project" {} + +locals { + vpc_connector_name = lower(var.environment) == "dev" ? "vpc-connector-qa" : "vpc-connector-${lower(var.environment)}" + vpc_connector_project = lower(var.environment) == "dev" ? "mobility-feeds-qa" : var.project_id + + service_account_roles = [ + # Cloud Logging: allows writing logs to GCP + "roles/logging.logWriter", + # Cloud Trace: allows writing trace and span data + "roles/cloudtrace.agent", + # Cloud Monitoring: allows publishing custom metrics + "roles/monitoring.metricWriter", + # Serverless VPC Access: required to use a VPC connector + "roles/vpcaccess.user", + ] + + service_account_role_bindings = { + for role in local.service_account_roles : + "${role}" => { + role = role + project = var.project_id + } + } +} + +data "google_vpc_access_connector" "vpc_connector" { + name = local.vpc_connector_name + region = var.gcp_region + project = local.vpc_connector_project +} + +resource "google_service_account" "mcp_service_account" { + account_id = "mcp-service-account" + display_name = "MCP Server Service Account" +} + +resource "google_cloud_run_v2_service" "mcp_server" { + name = "mcp-server-${var.environment}" + location = var.gcp_region + ingress = "INGRESS_TRAFFIC_ALL" + + template { + service_account = google_service_account.mcp_service_account.email + + # Use annotations for max instances on older provider(<6.0.0) + annotations = { + "run.googleapis.com/max-instances" = "10" + } + + vpc_access { + connector = data.google_vpc_access_connector.vpc_connector.id + egress = "ALL_TRAFFIC" + } + + containers { + image = "${var.gcp_region}-docker.pkg.dev/${var.project_id}/${var.docker_repository_name}/mcp-server:${var.mcp_image_version}" + + env { + name = "FEEDS_DATABASE_URL" + value_source { + secret_key_ref { + secret = "${upper(var.environment)}_FEEDS_DATABASE_URL" + version = "latest" + } + } + } + + env { + name = "DATASETS_BUCKET_URL" + value = "https://storage.googleapis.com/mobilitydata-datasets-${lower(var.environment)}" + } + + resources { + limits = { + cpu = "1" + memory = "512Mi" + } + } + } + } +} + +data "google_iam_policy" "noauth" { + binding { + role = "roles/run.invoker" + members = ["allUsers"] + } +} + +resource "google_cloud_run_service_iam_policy" "noauth" { + location = google_cloud_run_v2_service.mcp_server.location + project = google_cloud_run_v2_service.mcp_server.project + service = google_cloud_run_v2_service.mcp_server.name + policy_data = data.google_iam_policy.noauth.policy_data +} + +resource "google_secret_manager_secret_iam_member" "feeds_db_url_access" { + project = var.project_id + secret_id = "${upper(var.environment)}_FEEDS_DATABASE_URL" + role = "roles/secretmanager.secretAccessor" + member = "serviceAccount:${google_service_account.mcp_service_account.email}" +} + +resource "google_project_iam_member" "mcp_service_account_roles" { + for_each = local.service_account_role_bindings + + project = each.value.project + role = each.value.role + member = "serviceAccount:${google_service_account.mcp_service_account.email}" +} + +output "mcp_server_uri" { + value = google_cloud_run_v2_service.mcp_server.uri + description = "URI of the MCP Server Cloud Run service" +} + +output "mcp_server_name" { + value = google_cloud_run_v2_service.mcp_server.name + description = "Name of the MCP Server Cloud Run service" +} diff --git a/infra/mcp/vars.tf b/infra/mcp/vars.tf new file mode 100644 index 000000000..907c9a78a --- /dev/null +++ b/infra/mcp/vars.tf @@ -0,0 +1,40 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +variable "project_id" { + type = string + description = "GCP project ID" +} + +variable "gcp_region" { + type = string + description = "GCP region" +} + +variable "environment" { + type = string + description = "Deployment environment (prod, staging, dev)" +} + +variable "docker_repository_name" { + type = string + description = "Artifact Registry Docker repository name" +} + +variable "mcp_image_version" { + type = string + description = "Docker image version/tag for the MCP server" +} diff --git a/infra/vars.tf b/infra/vars.tf index b62fac55b..7d42d919d 100644 --- a/infra/vars.tf +++ b/infra/vars.tf @@ -89,4 +89,9 @@ variable "web_app_revalidate_secret" { description = "Secret token used to authenticate requests to the website revalidation endpoint" sensitive = true default = "" +} + +variable "mcp_image_version" { + type = string + description = "Docker image version/tag for the MCP server" } \ No newline at end of file diff --git a/infra/vars.tfvars.rename_me b/infra/vars.tfvars.rename_me index 1aab1eae3..74dc1e59d 100644 --- a/infra/vars.tfvars.rename_me +++ b/infra/vars.tfvars.rename_me @@ -11,6 +11,7 @@ artifact_repo_name = {{ARTIFACT_REPO_NAME}} deployer_service_account = {{DEPLOYER_SERVICE_ACCOUNT}} feed_api_image_version = {{FEED_API_IMAGE_VERSION}} +mcp_image_version = {{MCP_IMAGE_VERSION}} oauth2_client_id = {{OAUTH2_CLIENT_ID}} oauth2_client_secret = {{OAUTH2_CLIENT_SECRET}} diff --git a/mcp/.env.rename_me b/mcp/.env.rename_me new file mode 100644 index 000000000..2f40301c2 --- /dev/null +++ b/mcp/.env.rename_me @@ -0,0 +1,4 @@ +FEEDS_DATABASE_URL=postgresql://user:pass@localhost:5432/feeds +DATASETS_BUCKET_URL=https://storage.googleapis.com/mobilitydata-datasets-prod +PORT=8080 +FEED_CACHE_TTL_SECONDS=600 diff --git a/mcp/.gitignore b/mcp/.gitignore new file mode 100644 index 000000000..26bcf9d8c --- /dev/null +++ b/mcp/.gitignore @@ -0,0 +1 @@ +shared \ No newline at end of file diff --git a/mcp/Dockerfile b/mcp/Dockerfile new file mode 100644 index 000000000..351181286 --- /dev/null +++ b/mcp/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Copy requirements first for layer caching +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy shared code from api (symlinks won't work in Docker, copy directly) +COPY api/src/shared /app/shared + +# Copy MCP server source +COPY mcp/src/ /app/ + +# Set Python path so imports work +ENV PYTHONPATH=/app + +EXPOSE 8080 + +CMD ["python", "main.py"] diff --git a/mcp/README.md b/mcp/README.md new file mode 100644 index 000000000..db170eff4 --- /dev/null +++ b/mcp/README.md @@ -0,0 +1,207 @@ +# MobilityDatabase MCP Server + +An [MCP (Model Context Protocol)](https://modelcontextprotocol.io) server that exposes MobilityDatabase's GTFS/GBFS data as AI-accessible tools. Connect it to Claude Desktop (or any MCP client) to query feeds, search by location, and explore transit data conversationally. + +## What's been implemented + +**Step 1 — Smart Search (`search_feeds` tool)** + +Searches the Mobility Database using PostgreSQL full-text search against the `FeedSearch` materialized view. Returns rich location and metadata context so the AI can disambiguate results — e.g. distinguishing "Montréal, QC, Canada" from "Montréal-du-Gers, France." + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `search_query` | string | — | Free-text search (e.g. `"Montreal"`, `"STM"`, `"Japan"`) | +| `data_type` | string | `gtfs` | `gtfs`, `gtfs_rt`, or `gbfs` | +| `status` | string | none | `active`, `inactive`, or `deprecated` | +| `is_official` | boolean | none | Filter to official feeds only | +| `limit` | integer | `30` | Max results | + +**Step 2 — Validation Results (`get_validation_results` tool)** + +Returns per-feed validation results enriched with: +- Rule documentation (description + link to official validator docs) +- Sample rows from the affected GTFS file (fetched from GCS public URLs) +- Full validation summary (error/warning/info counts) + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `feed_id` | string | — | Mobility Database feed ID (e.g. `mdb-1210`) | +| `severity_filter` | string | `all` | `all`, `errors`, `warnings`, or `info` | + +> Rule documentation is fetched from `https://gtfs-validator.mobilitydata.org/rules.json` on first use and cached in memory for 24 hours. No manual refresh needed. + +**Step 3 — GTFS SQL Query Engine (`query_gtfs` tool)** + +Loads a feed's extracted GTFS files into an in-memory DuckDB database and executes SQL queries against them. Results are cached per feed for 10 minutes so follow-up queries are instant. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `feed_id` | string | — | Mobility Database feed ID (e.g. `mdb-1210`) | +| `query` | string | — | `SCHEMA` to list tables/columns, or any SQL `SELECT` statement | + +> Use `SCHEMA` first to discover available tables, then write SQL to answer your question. Only `SELECT` queries are allowed. + +## Running locally + +### Prerequisites + +- Python 3.11+ +- Access to the MobilityDatabase PostgreSQL instance (or a local copy) + +### Setup + +```bash +# 1. From the repo root, link shared code into mcp/src/shared/ +./scripts/function-python-setup.sh --mcp + +# 2. Install dependencies +cd mcp +pip install -r requirements.txt + +# 3. Configure environment +cp .env.rename_me .env +# Edit .env and set FEEDS_DATABASE_URL to your database connection string + +# 4. Start the server +cd src +PYTHONPATH=. python main.py stdio +``` + +### Environment variables + +| Variable | Description | Example | +|---|---|---| +| `FEEDS_DATABASE_URL` | PostgreSQL connection string for Mobility Database data | `postgresql://user:pass@host:5432/feeds` | +| `DATASETS_BUCKET_URL` | Base URL for GCS-hosted extracted GTFS files. Used to fetch sample rows when enriching validation notices. No auth required — public bucket. | `https://storage.googleapis.com/mobilitydata-datasets-prod` | +| `FEED_CACHE_TTL_SECONDS` | How long to cache loaded GTFS datasets in memory (seconds) | `600` | +| `PORT` | Server port for SSE transport | `8080` | + +```bash +DATASETS_BUCKET_URL=https://storage.googleapis.com/mobilitydata-datasets-prod +``` + +### Running tests + +```bash +cd mcp +python -m pytest tests/ -v +``` + +## Running with Docker + +The Dockerfile build context is the **repo root** (needed to copy shared code): + +```bash +# Build +docker build -f mcp/Dockerfile -t mcp-server . + +# Run +docker run -p 8080:8080 \ + -e FEEDS_DATABASE_URL="postgresql://user:pass@host:5432/feeds" \ + -e DATASETS_BUCKET_URL="https://storage.googleapis.com/mobilitydata-datasets-prod" \ + mcp-server +``` + +## Connecting to Claude Desktop + +Add the following to your Claude Desktop config (`~/Library/Application Support/Claude/claude_desktop_config.json` on macOS). + +Replace `/path/to/mobility-feed-api` with your actual repo path, and update the environment values as needed: + +```json +{ + "mcpServers": { + "mobilitydatabase": { + "command": "python", + "args": ["main.py", "stdio"], + "cwd": "/path/to/mobility-feed-api/mcp/src", + "env": { + "PYTHONPATH": "/path/to/mobility-feed-api/mcp/src", + "FEEDS_DATABASE_URL": "postgresql://user:pass@localhost:5432/feeds", + "DATASETS_BUCKET_URL": "https://storage.googleapis.com/mobilitydata-datasets-prod", + "FEED_CACHE_TTL_SECONDS": "600" + } + } + } +} +``` + +> If you use a virtual environment, point `command` to the venv's Python binary (e.g. `/path/to/mobility-feed-api/mcp/.venv/bin/python`). + +Restart Claude Desktop. You should see the `search_feeds`, `get_validation_results`, and `query_gtfs` tools available. Try: +> *"Find active GTFS feeds in Montreal"* +> *"Search for STM feeds"* +> *"Show me official feeds in Japan"* +> *"What validation errors does mdb-1210 have?"* +> *"Show me the errors in the STM feed"* + +## Questions that feel like magic + +### Single-tool deep dives + +Once Claude can query raw GTFS tables directly, you can ask things that would be painful to answer by hand: + +> *"Which route in the STM network serves the most unique stops?"* +> *"Are there any trips that depart after midnight? List them with their headsigns."* +> *"What percentage of stops in the Tokyo feed have wheelchair boarding info?"* +> *"Which agency in the feed operates the most distinct routes?"* +> *"Find all stops within 500m of each other that belong to different routes — potential transfer point opportunities."* +> *"How many trips run on weekdays vs weekends for each route?"* +> *"What's the average dwell time between consecutive stops on the busiest route?"* +> *"Show me routes that have no shapes defined — they'll render as straight lines on a map."* +> *"Which stops appear in stop_times but are missing from stops.txt?"* +> *"What's the earliest first departure and latest last departure across all routes?"* + +These are the kinds of questions that become conversational once an AI can inspect `trips.txt`, `stop_times.txt`, `routes.txt`, `calendar.txt`, `shapes.txt`, and the rest through SQL. + +### Cross-tool investigations + +The real power emerges when the tools chain together. Claude will figure out the sequence — you just ask the question. + +> *"Find all active GTFS feeds in California, then tell me which ones have validation errors and what those errors are."* +> → `search_feeds` (California) → `get_validation_results` for each → summary report + +> *"I'm traveling in Tokyo next month — which feed covers the Tokyo Metro, does it have any errors that could affect trip planning, and what routes serve Shinjuku station?"* +> → `search_feeds` (Tokyo Metro) → `get_validation_results` → `query_gtfs` (stops near Shinjuku) + +> *"Compare the data quality of the top 5 transit agencies in Canada — who has the cleanest feed?"* +> → `search_feeds` (Canada, limit=5) → `get_validation_results` for each → ranked comparison by error count + +> *"The STM feed has a `stop_times_with_only_arrival_time` warning — can you show me which trips are affected and what the schedule looks like for those stops?"* +> → `get_validation_results` (mdb-956, warnings) → `query_gtfs` (SELECT from stop_times WHERE arrival_time IS NOT NULL AND departure_time IS NULL) + +> *"Are there any official feeds in Europe with zero validation errors? If so, what GTFS features do they support?"* +> → `search_feeds` (Europe, is_official=true) → `get_validation_results` for each → filter zero errors → list features + +> *"Find the feed for the Paris Métro, check if it has wheelchair accessibility data, and count what fraction of stops actually have it filled in."* +> → `search_feeds` (Paris Métro) → `get_validation_results` (check features list) → `query_gtfs` (SELECT COUNT(*) by wheelchair_boarding value) + +> *"Which city in Japan has the most comprehensive GTFS feed — most routes, most stops, fewest errors?"* +> → `search_feeds` (Japan) → `get_validation_results` for each → `query_gtfs` SCHEMA on top candidates → compare route/stop counts + +## Architecture + +``` +Claude Desktop (or any MCP client) + │ MCP Protocol (stdio locally, SSE when deployed) + ▼ + MCP Server (Python) + ├──► SQLAlchemy → PostgreSQL (feed metadata, search, validation reports) + └──► DuckDB (in-memory) ← GTFS CSVs fetched from GCS public URLs +``` + +The server connects **directly to the database** — it does not call the public Feed API. It reuses the shared database models and query logic from `api/src/shared/` (linked via symlinks in `src/shared/`). + +## Deployment + +Terraform infrastructure is in `infra/mcp/`. The module creates a Cloud Run service (`mcp-server-{env}`) and is wired into the root `infra/main.tf`. Deploy by building and pushing the Docker image to Artifact Registry, then running terraform: + +```bash +# Build and push (from repo root) +docker build -f mcp/Dockerfile -t {region}-docker.pkg.dev/{project}/{repo}/mcp-server:{version} . +docker push {region}-docker.pkg.dev/{project}/{repo}/mcp-server:{version} + +# Apply terraform +cd infra +terraform apply -var="mcp_image_version={version}" +``` diff --git a/mcp/mcp_config.json b/mcp/mcp_config.json new file mode 100644 index 000000000..7830f325c --- /dev/null +++ b/mcp/mcp_config.json @@ -0,0 +1,5 @@ +{ + "name": "mcp-server", + "description": "MobilityDatabase MCP Server", + "include_api_folders": ["common", "database", "database_gen", "feed_filters", "db_models"] +} diff --git a/mcp/pyproject.toml b/mcp/pyproject.toml new file mode 100644 index 000000000..da606462f --- /dev/null +++ b/mcp/pyproject.toml @@ -0,0 +1,3 @@ +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] diff --git a/mcp/requirements.txt b/mcp/requirements.txt new file mode 100644 index 000000000..095c71dd5 --- /dev/null +++ b/mcp/requirements.txt @@ -0,0 +1,7 @@ +mcp>=1.0.0 +sqlalchemy>=2.0.0 +psycopg2-binary>=2.9.0 +geoalchemy2>=0.14.0 +python-dotenv>=1.0.0 +duckdb>=0.10.0 +httpx>=0.24.0 diff --git a/mcp/run.sh b/mcp/run.sh new file mode 100755 index 000000000..436c41d8f --- /dev/null +++ b/mcp/run.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +cd "$(dirname "$0")/src" +exec "$(dirname "$0")/.venv/bin/python" main.py stdio diff --git a/mcp/src/__init__.py b/mcp/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mcp/src/client.py b/mcp/src/client.py new file mode 100644 index 000000000..beaa1b6f0 --- /dev/null +++ b/mcp/src/client.py @@ -0,0 +1,12 @@ +import asyncio +from mcp import ClientSession +from mcp.client.sse import sse_client + +async def test(): + async with sse_client("http://localhost:8080/sse") as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + tools = await session.list_tools() + print([t.name for t in tools.tools]) + +asyncio.run(test()) \ No newline at end of file diff --git a/mcp/src/gtfs_cache.py b/mcp/src/gtfs_cache.py new file mode 100644 index 000000000..d5f10fbc4 --- /dev/null +++ b/mcp/src/gtfs_cache.py @@ -0,0 +1,89 @@ +import os +import threading +import time +from typing import Callable, Optional + +import duckdb + +DEFAULT_TTL_SECONDS = int(os.getenv("FEED_CACHE_TTL_SECONDS", "600")) + + +class _CacheEntry: + __slots__ = ("connection", "loaded_at", "lock") + + def __init__(self): + self.connection: Optional[duckdb.DuckDBPyConnection] = None + self.loaded_at: float = 0.0 + self.lock = threading.Lock() + + +class GtfsCache: + """Thread-safe singleton TTL cache for GTFS DuckDB connections.""" + + _instance: Optional["GtfsCache"] = None + _instance_lock = threading.Lock() + + def __new__(cls, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> "GtfsCache": + if cls._instance is None: + with cls._instance_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._ttl_seconds = ttl_seconds + cls._instance._cache: dict[tuple[str, str], _CacheEntry] = {} + cls._instance._cache_lock = threading.Lock() + return cls._instance + + def _is_fresh(self, entry: _CacheEntry) -> bool: + return ( + entry.connection is not None + and time.monotonic() - entry.loaded_at <= self._ttl_seconds + ) + + def get_or_load( + self, + feed_id: str, + dataset_id: str, + loader_fn: Callable[[], duckdb.DuckDBPyConnection], + ) -> duckdb.DuckDBPyConnection: + key = (feed_id, dataset_id) + entry = self._cache.get(key) + if entry and self._is_fresh(entry): + return entry.connection + + if entry is None: + with self._cache_lock: + entry = self._cache.get(key) + if entry is None: + entry = _CacheEntry() + self._cache[key] = entry + + old_connection: Optional[duckdb.DuckDBPyConnection] = None + with entry.lock: + if self._is_fresh(entry): + return entry.connection + old_connection = entry.connection + entry.connection = loader_fn() + entry.loaded_at = time.monotonic() + connection = entry.connection + + if old_connection is not None and old_connection is not connection: + try: + old_connection.close() + except Exception: + pass + + return connection + + +_gtfs_cache: Optional[GtfsCache] = None +_gtfs_cache_lock = threading.Lock() + + +def get_gtfs_cache() -> GtfsCache: + """Return the module-level GTFS cache singleton.""" + global _gtfs_cache + if _gtfs_cache is None: + with _gtfs_cache_lock: + if _gtfs_cache is None: + _gtfs_cache = GtfsCache() + return _gtfs_cache diff --git a/mcp/src/main.py b/mcp/src/main.py new file mode 100644 index 000000000..afd135162 --- /dev/null +++ b/mcp/src/main.py @@ -0,0 +1,22 @@ +import os +import sys +from dotenv import load_dotenv + +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +from mcp.server.fastmcp import FastMCP +from tools.get_validation_results import get_validation_results_tool +from tools.query_gtfs import query_gtfs_tool +from tools.search_feeds import search_feeds_tool + +port = int(os.getenv("PORT", "8080")) + +# host and port must be passed to the constructor in mcp 1.x (not to run()) +mcp = FastMCP("mobilitydatabase-mcp", host="0.0.0.0", port=port) +mcp.tool()(search_feeds_tool) +mcp.tool()(get_validation_results_tool) +mcp.tool()(query_gtfs_tool) + +if __name__ == "__main__": + transport = sys.argv[1] if len(sys.argv) > 1 else "sse" + mcp.run(transport=transport) diff --git a/mcp/src/rules_cache.py b/mcp/src/rules_cache.py new file mode 100644 index 000000000..80e2e1e0c --- /dev/null +++ b/mcp/src/rules_cache.py @@ -0,0 +1,126 @@ +""" +In-memory TTL cache for GTFS validator rule documentation. +Fetches rules.json from the official validator site on first use, +then serves from memory until TTL expires (default: 24 hours). +""" +import logging +import threading +import time +from typing import Optional + +import httpx + +logger = logging.getLogger(__name__) + +RULES_JSON_URL = "https://gtfs-validator.mobilitydata.org/rules.json" +RULES_BASE_URL = "https://gtfs-validator.mobilitydata.org/rules.html" +DEFAULT_TTL_SECONDS = 24 * 3600 # 24 hours + + +class RuleDoc: + """Structured rule documentation entry.""" + + __slots__ = ("code", "short_summary", "description", "affected_files", "rule_url") + + def __init__( + self, + code: str, + short_summary: str, + description: str, + affected_files: list[str], + rule_url: str, + ): + self.code = code + self.short_summary = short_summary + self.description = description + self.affected_files = affected_files + self.rule_url = rule_url + + def to_dict(self) -> dict: + return { + "description": self.description, + "short_summary": self.short_summary, + "affected_files": self.affected_files, + "rule_url": self.rule_url, + } + + +class RulesCache: + """Thread-safe singleton cache for GTFS validator rule docs.""" + + _instance: Optional["RulesCache"] = None + _lock = threading.Lock() + + def __new__(cls, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> "RulesCache": + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._ttl_seconds = ttl_seconds + cls._instance._cache: dict[str, RuleDoc] = {} + cls._instance._fetched_at: float = 0.0 + cls._instance._fetch_lock = threading.Lock() + return cls._instance + + def _is_stale(self) -> bool: + return time.monotonic() - self._fetched_at > self._ttl_seconds + + def _fetch_and_build(self) -> None: + """Fetch rules.json and build the lookup dict. Called under _fetch_lock.""" + logger.info("Fetching GTFS validator rules from %s", RULES_JSON_URL) + try: + with httpx.Client(timeout=30.0) as client: + response = client.get(RULES_JSON_URL) + response.raise_for_status() + raw: dict = response.json() + except Exception as exc: + logger.error("Failed to fetch rules.json: %s", exc) + # Keep serving stale data if we have it, otherwise propagate + if not self._cache: + raise + logger.warning("Serving stale rules cache due to fetch failure") + return + + new_cache: dict[str, RuleDoc] = {} + for code, entry in raw.items(): + refs = entry.get("references", {}) + file_refs: list[str] = refs.get("fileReferences", []) or [] + new_cache[code] = RuleDoc( + code=code, + short_summary=entry.get("shortSummary", ""), + description=entry.get("description", ""), + affected_files=file_refs, + rule_url=f"{RULES_BASE_URL}#{code}-rule", + ) + + self._cache = new_cache + self._fetched_at = time.monotonic() + logger.info("Rules cache populated: %d rules loaded", len(new_cache)) + + def get(self, notice_code: str) -> Optional[RuleDoc]: + """Look up a rule by notice code. Fetches/refreshes cache as needed.""" + if self._is_stale(): + with self._fetch_lock: + # Double-check after acquiring lock + if self._is_stale(): + self._fetch_and_build() + return self._cache.get(notice_code) + + def get_dict(self, notice_code: str) -> Optional[dict]: + """Return the rule as a plain dict, or None if unknown.""" + rule = self.get(notice_code) + return rule.to_dict() if rule else None + + @property + def size(self) -> int: + return len(self._cache) + + +_rules_cache: Optional[RulesCache] = None + + +def get_rules_cache() -> RulesCache: + """Return the module-level RulesCache singleton.""" + global _rules_cache + if _rules_cache is None: + _rules_cache = RulesCache() + return _rules_cache diff --git a/mcp/src/tools/__init__.py b/mcp/src/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mcp/src/tools/get_validation_results.py b/mcp/src/tools/get_validation_results.py new file mode 100644 index 000000000..b225ffa1d --- /dev/null +++ b/mcp/src/tools/get_validation_results.py @@ -0,0 +1,173 @@ +""" +MCP tool: get_validation_results +Retrieves validation results for a GTFS feed enriched with rule docs and GTFS file samples. +""" +import csv +import io +import json +import logging +import os +from typing import Optional + +import httpx +from sqlalchemy.orm import joinedload + +from rules_cache import get_rules_cache +from shared.database.database import Database +from shared.database_gen.sqlacodegen_models import Gtfsdataset, Gtfsfeed, Validationreport + +logger = logging.getLogger(__name__) + +SAMPLE_ROWS = 10 +SEVERITY_MAP = { + "errors": "ERROR", + "warnings": "WARNING", + "info": "INFO", +} + + +def _fetch_gtfs_sample(feed_stable_id: str, dataset_stable_id: str, filename: str) -> Optional[dict]: + """Fetch a GTFS CSV file from GCS and return columns, sample rows, and total row count.""" + datasets_bucket_url = os.getenv("DATASETS_BUCKET_URL", "") + if not datasets_bucket_url: + return None + + url = f"{datasets_bucket_url}/{feed_stable_id}/{dataset_stable_id}/extracted/{filename}" + try: + with httpx.Client(timeout=15.0) as client: + response = client.get(url) + if response.status_code == 404: + return None + response.raise_for_status() + content = response.text + except Exception as exc: + logger.warning("Failed to fetch GTFS file %s: %s", url, exc) + return None + + reader = csv.DictReader(io.StringIO(content)) + columns = reader.fieldnames or [] + rows = [] + total = 0 + for row in reader: + total += 1 + if len(rows) < SAMPLE_ROWS: + rows.append(dict(row)) + + return { + "filename": filename, + "source_url": url, + "columns": list(columns), + "sample_rows": rows, + "total_rows": total, + } + + +def get_validation_results_tool( + feed_id: str, + severity_filter: Optional[str] = "all", +) -> str: + """ + Retrieve validation results for a GTFS feed, enriched with rule documentation + and sample rows from the affected GTFS files. + + Args: + feed_id: Mobility Database feed ID (e.g. "mdb-1210") + severity_filter: Filter notices by severity. One of: all, errors, warnings, info. Default: all + + Returns: + JSON string with feed metadata, validation summary, and enriched notice list + """ + rules = get_rules_cache() + + db = Database() + with db.start_db_session() as session: + feed = ( + session.query(Gtfsfeed) + .filter(Gtfsfeed.stable_id == feed_id) + .options( + joinedload(Gtfsfeed.latest_dataset) + .joinedload(Gtfsdataset.validation_reports) + .joinedload(Validationreport.notices), + joinedload(Gtfsfeed.latest_dataset) + .joinedload(Gtfsdataset.validation_reports) + .joinedload(Validationreport.features), + ) + .first() + ) + + if feed is None: + return json.dumps({"error": f"Feed '{feed_id}' not found."}) + + dataset = feed.latest_dataset + if dataset is None: + return json.dumps({"error": f"Feed '{feed_id}' has no dataset yet."}) + + report: Optional[Validationreport] = None + if dataset.validation_reports: + report = max(dataset.validation_reports, key=lambda item: item.validated_at or 0) + + if report is None: + return json.dumps( + { + "feed_id": feed_id, + "provider": feed.provider, + "dataset_id": dataset.stable_id, + "error": "No validation report found for this feed.", + } + ) + + severity_value = SEVERITY_MAP.get((severity_filter or "all").lower()) + notices = report.notices + if severity_value: + notices = [notice for notice in notices if notice.severity == severity_value] + + file_sample_cache: dict[str, Optional[dict]] = {} + enriched_notices = [] + for notice in sorted(notices, key=lambda item: (item.severity or "", -(item.total_notices or 0))): + rule_doc = rules.get_dict(notice.notice_code) + + affected_file_data = None + affected_files = rule_doc.get("affected_files") if rule_doc else [] + if affected_files: + primary_file = affected_files[0] + if primary_file not in file_sample_cache: + file_sample_cache[primary_file] = _fetch_gtfs_sample( + feed_id, dataset.stable_id, primary_file + ) + affected_file_data = file_sample_cache[primary_file] + + enriched_notices.append( + { + "code": notice.notice_code, + "severity": notice.severity, + "total_notices": notice.total_notices, + "rule_doc": rule_doc, + "affected_file": affected_file_data, + } + ) + + feature_names = [feature.name for feature in report.features] + + result = { + "feed_id": feed_id, + "provider": feed.provider, + "dataset_id": dataset.stable_id, + "validated_at": str(report.validated_at) if report.validated_at else None, + "validator_version": report.validator_version, + "report_urls": { + "json": report.json_report, + "html": report.html_report, + }, + "gtfs_features": feature_names, + "summary": { + "total_errors": report.total_error, + "total_warnings": report.total_warning, + "total_info": report.total_info, + "unique_error_count": report.unique_error_count, + "unique_warning_count": report.unique_warning_count, + "unique_info_count": report.unique_info_count, + }, + "notices": enriched_notices, + } + + return json.dumps(result, default=str) diff --git a/mcp/src/tools/query_gtfs.py b/mcp/src/tools/query_gtfs.py new file mode 100644 index 000000000..fbffb0040 --- /dev/null +++ b/mcp/src/tools/query_gtfs.py @@ -0,0 +1,258 @@ +import json +import logging +import os +import time +from typing import Optional + +import duckdb + +from gtfs_cache import get_gtfs_cache +from shared.database.database import Database +from shared.database_gen.sqlacodegen_models import Gtfsfeed + +logger = logging.getLogger(__name__) + +STANDARD_GTFS_FILES = [ + "agency.txt", + "stops.txt", + "routes.txt", + "trips.txt", + "stop_times.txt", + "calendar.txt", + "calendar_dates.txt", + "shapes.txt", + "fare_attributes.txt", + "fare_rules.txt", + "frequencies.txt", + "transfers.txt", + "feed_info.txt", + "pathways.txt", + "levels.txt", + "translations.txt", + "attributions.txt", +] +ROW_LIMIT = 1000 + + +def _quote_identifier(value: str) -> str: + return '"' + value.replace('"', '""') + '"' + + +def _table_name_for_file(filename: str) -> str: + return filename[:-4] if filename.endswith(".txt") else filename + + +def _file_name_for_table(table_name: str) -> str: + return f"{table_name}.txt" + + +STANDARD_TABLE_NAMES = frozenset(_table_name_for_file(f) for f in STANDARD_GTFS_FILES) + + +def _load_duckdb(feed_id: str, dataset_id: str, datasets_bucket_url: str, files: list[str]) -> duckdb.DuckDBPyConnection: + """Load GTFS files directly from GCS into an in-memory DuckDB via httpfs.""" + con = duckdb.connect() + con.install_extension("httpfs") + con.load_extension("httpfs") + + base_url = f"{datasets_bucket_url}/{feed_id}/{dataset_id}/extracted" + for filename in files: + table_name = _table_name_for_file(filename) + url = f"{base_url}/{filename}" + try: + con.execute( + f"CREATE TABLE {_quote_identifier(table_name)} AS " + f"SELECT * FROM read_csv_auto('{url}', all_varchar=true)" + ) + except Exception as exc: + logger.warning("Failed to load %s: %s", filename, exc) + + return con + + +def _resolve_dataset(feed_id: str) -> tuple[str | None, str | None]: + """Return (dataset_stable_id, error_json) for a given feed_id.""" + db = Database() + with db.start_db_session() as session: + feed = session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == feed_id).first() + if feed is None: + return None, json.dumps({"error": f"Feed '{feed_id}' not found."}) + + dataset = feed.latest_dataset + if dataset is None: + return None, json.dumps({"error": f"Feed '{feed_id}' has no dataset yet."}) + + return dataset.stable_id, None + + +def _get_schema(con: duckdb.DuckDBPyConnection) -> tuple[dict, list[str]]: + tables = {} + table_names = [row[0] for row in con.execute("SHOW TABLES").fetchall()] + for table_name in sorted(table_names): + columns = [row[1] for row in con.execute(f"PRAGMA table_info({_quote_identifier(table_name)})").fetchall()] + row_count = con.execute(f"SELECT COUNT(*) FROM {_quote_identifier(table_name)}").fetchone()[0] + tables[table_name] = {"columns": columns, "row_count": row_count} + available_files = [f"{table_name}.txt" for table_name in sorted(table_names)] + return tables, available_files + + +def _extract_tables_from_query(query: str) -> list[str]: + """Extract referenced GTFS table names from a SQL query.""" + if not query: + return [] + upper = query.upper() + found = [] + for table_name in STANDARD_TABLE_NAMES: + if table_name.upper() in upper: + found.append(table_name) + return found + + +def _validate_files( + files: Optional[list[str]], query: Optional[str] = None +) -> tuple[list[str], Optional[str]]: + """Validate and normalize the files list. + + Returns (filenames, error). When *error* is not None the caller should + return it directly — *filenames* will be empty. + """ + user_provided = bool(files) + if not files and query and query.strip().upper() != "SCHEMA": + files = _extract_tables_from_query(query) + + if not files: + return list(STANDARD_GTFS_FILES), None + + normalized = [] + invalid = [] + for f in files: + name = f.strip() + if not name: + continue + table_name = _table_name_for_file(name) if name.endswith(".txt") else name + if table_name not in STANDARD_TABLE_NAMES: + invalid.append(name) + continue + normalized.append(_file_name_for_table(table_name)) + + if invalid and user_provided: + valid_list = ", ".join(sorted(STANDARD_TABLE_NAMES)) + return [], json.dumps({ + "error": f"Invalid GTFS file(s): {', '.join(invalid)}. " + f"Valid files are: {valid_list}", + }) + + return (normalized if normalized else list(STANDARD_GTFS_FILES)), None + + +def query_gtfs_tool(feed_id: str, query: str, files: Optional[list[str]] = None) -> str: + """ + Load a GTFS feed into an in-memory DuckDB database and execute SQL queries. + + Use query="SCHEMA" first to discover available tables and columns. + Then write SQL SELECT queries to answer questions about routes, stops, schedules, etc. + + Results are cached per feed/file-set for FEED_CACHE_TTL_SECONDS (default: 10 minutes). + + Args: + feed_id: Mobility Database feed ID (e.g. "mdb-1210") + query: Either "SCHEMA" to list tables/columns, or a SQL SELECT statement + files: Optional list of GTFS files to load (e.g. ["stops", "routes", "trips"]). + Accepts table names or filenames (e.g. "stops" or "stops.txt"). + If omitted, all standard GTFS files are loaded. + Tip: only load the tables you need for much faster responses. + + Returns: + JSON string with schema info or query results + """ + datasets_bucket_url = os.getenv("DATASETS_BUCKET_URL", "") + if not datasets_bucket_url: + return json.dumps({"error": "DATASETS_BUCKET_URL is not configured."}) + + dataset_id, error = _resolve_dataset(feed_id) + if error: + return error + logging.info("Querying GTFS feed %s dataset %s with files %s", feed_id, dataset_id, files) + target_files, validation_error = _validate_files(files, query) + if validation_error: + return validation_error + # Cache key includes sorted file set so different subsets are cached independently + cache_key_suffix = ",".join(sorted(target_files)) + cache = get_gtfs_cache() + started_at = time.perf_counter() + + try: + con = cache.get_or_load( + feed_id, + f"{dataset_id}:{cache_key_suffix}", + lambda: _load_duckdb(feed_id, dataset_id, datasets_bucket_url, target_files), + ) + except Exception as exc: + logger.exception("Failed to load GTFS feed %s dataset %s", feed_id, dataset_id) + return json.dumps( + { + "feed_id": feed_id, + "dataset_id": dataset_id, + "error": f"Failed to load GTFS feed: {exc}", + }, + default=str, + ) + + if (query or "").strip().upper() == "SCHEMA": + cursor = con.cursor() + try: + tables, available_files = _get_schema(cursor) + finally: + cursor.close() + return json.dumps( + { + "feed_id": feed_id, + "dataset_id": dataset_id, + "tables": tables, + "available_files": available_files, + }, + default=str, + ) + + normalized_query = (query or "").strip().rstrip(";") + if not normalized_query.upper().startswith("SELECT"): + return json.dumps( + { + "feed_id": feed_id, + "dataset_id": dataset_id, + "error": "Only SQL SELECT queries are allowed. Use SCHEMA to inspect tables.", + } + ) + + limited_sql = f"SELECT * FROM ({normalized_query}) AS _query_gtfs LIMIT {ROW_LIMIT}" + cursor = con.cursor() + try: + result = cursor.execute(limited_sql) + rows = [list(row) for row in result.fetchall()] + columns = [column[0] for column in (result.description or [])] + except Exception as exc: + return json.dumps( + { + "feed_id": feed_id, + "dataset_id": dataset_id, + "sql": limited_sql, + "error": f"Query failed: {exc}", + }, + default=str, + ) + finally: + cursor.close() + + execution_time_ms = int((time.perf_counter() - started_at) * 1000) + return json.dumps( + { + "feed_id": feed_id, + "dataset_id": dataset_id, + "sql": limited_sql, + "columns": columns, + "rows": rows, + "row_count": len(rows), + "execution_time_ms": execution_time_ms, + }, + default=str, + ) diff --git a/mcp/src/tools/search_feeds.py b/mcp/src/tools/search_feeds.py new file mode 100644 index 000000000..3101c86e2 --- /dev/null +++ b/mcp/src/tools/search_feeds.py @@ -0,0 +1,135 @@ +from typing import Optional +import json +from sqlalchemy import func, select, or_ +from shared.database.database import Database +from shared.database.sql_functions.unaccent import unaccent +from shared.database_gen.sqlacodegen_models import t_feedsearch + +feed_search_columns = [col for col in t_feedsearch.columns if col.name != "document"] + + +def get_parsed_search_tsquery(search_query: str): + parsed_query = f"{search_query.strip()}:*" if search_query and len(search_query.strip()) > 0 else "" + return func.plainto_tsquery("english", unaccent(parsed_query)) + + +def search_feeds_tool( + search_query: str, + data_type: Optional[str] = "gtfs", + status: Optional[str] = None, + is_official: Optional[bool] = None, + limit: Optional[int] = 30, +) -> str: + """ + Search the Mobility Database for GTFS/GBFS/GTFS-RT feeds. + + Returns rich location and metadata context so the AI can disambiguate between results + (e.g., Montreal Quebec vs Montréal-du-Gers France). + + Args: + search_query: Free-text search (e.g., "Montreal", "Japan", "STM") + data_type: One of: gtfs, gtfs_rt, gbfs. Default: gtfs + status: Feed status filter: active, deprecated, inactive. Default: no filter + is_official: Filter for official feeds only + limit: Max results to return. Default: 30 + + Returns: + JSON string with query, total_matches, and results array with feed metadata + """ + db = Database() + with db.start_db_session() as session: + ts_query = get_parsed_search_tsquery(search_query) + rank_expr = func.ts_rank(t_feedsearch.c.document, ts_query).label("rank") + + query = select(rank_expr, *feed_search_columns) + + # Always filter to published feeds only (public MCP — no auth context) + query = query.filter(t_feedsearch.c.operational_status == "published") + + if data_type: + data_types = [dt.strip().lower() for dt in data_type.split(",")] + query = query.where(t_feedsearch.c.data_type.in_(data_types)) + + if status: + statuses = [s.strip().lower() for s in status.split(",")] + query = query.where(t_feedsearch.c.status.in_(statuses)) + + if is_official is not None: + if is_official: + query = query.where(t_feedsearch.c.official.is_(True)) + else: + query = query.where( + or_(t_feedsearch.c.official.is_(False), t_feedsearch.c.official.is_(None)) + ) + + if search_query and len(search_query.strip()) > 0: + query = query.filter(t_feedsearch.c.document.op("@@")(ts_query)) + + if search_query and len(search_query.strip()) > 0: + query = query.order_by(t_feedsearch.c.created_at.desc(), rank_expr.desc()) + else: + query = query.order_by(t_feedsearch.c.created_at.desc()) + + # Build parallel count query with same filters + count_query = select(func.count(t_feedsearch.c.feed_id)) + count_query = count_query.filter(t_feedsearch.c.operational_status == "published") + if data_type: + count_query = count_query.where(t_feedsearch.c.data_type.in_(data_types)) + if status: + count_query = count_query.where(t_feedsearch.c.status.in_(statuses)) + if is_official is not None: + if is_official: + count_query = count_query.where(t_feedsearch.c.official.is_(True)) + else: + count_query = count_query.where( + or_(t_feedsearch.c.official.is_(False), t_feedsearch.c.official.is_(None)) + ) + if search_query and len(search_query.strip()) > 0: + count_query = count_query.filter(t_feedsearch.c.document.op("@@")(ts_query)) + + rows = session.execute(query.limit(limit)).fetchall() + total_count_result = session.execute(count_query).fetchone() + total_count = total_count_result[0] if total_count_result else 0 + + results = [] + for row in rows: + row_dict = dict(row._mapping) + result = { + "feed_id": row_dict.get("feed_stable_id"), + "provider": row_dict.get("provider"), + "feed_name": row_dict.get("feed_name"), + "data_type": row_dict.get("data_type"), + "status": row_dict.get("status"), + "is_official": row_dict.get("official"), + "locations": row_dict.get("locations") or [], + "latest_dataset": { + "id": row_dict.get("latest_dataset_id"), + "hosted_url": row_dict.get("latest_dataset_hosted_url"), + "downloaded_at": str(row_dict.get("latest_dataset_downloaded_at")) + if row_dict.get("latest_dataset_downloaded_at") + else None, + "service_date_range_start": str(row_dict.get("latest_dataset_service_date_range_start")) + if row_dict.get("latest_dataset_service_date_range_start") + else None, + "service_date_range_end": str(row_dict.get("latest_dataset_service_date_range_end")) + if row_dict.get("latest_dataset_service_date_range_end") + else None, + }, + "validation_summary": { + "total_error": row_dict.get("latest_total_error"), + "total_warning": row_dict.get("latest_total_warning"), + "total_info": row_dict.get("latest_total_info"), + }, + "features": row_dict.get("latest_dataset_features") or [], + "search_rank": float(row_dict.get("rank", 0)), + } + results.append(result) + + return json.dumps( + { + "query": search_query, + "total_matches": total_count, + "results": results, + }, + default=str, + ) diff --git a/mcp/tests/__init__.py b/mcp/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mcp/tests/test_config.json b/mcp/tests/test_config.json new file mode 100644 index 000000000..af348e24b --- /dev/null +++ b/mcp/tests/test_config.json @@ -0,0 +1,4 @@ +{ + "include_folders": [], + "include_api_folders": [] +} diff --git a/mcp/tests/test_get_validation_results.py b/mcp/tests/test_get_validation_results.py new file mode 100644 index 000000000..8397c53b6 --- /dev/null +++ b/mcp/tests/test_get_validation_results.py @@ -0,0 +1,283 @@ +""" +Unit tests for the get_validation_results MCP tool. +Tests use mocking to avoid requiring a live database or network. +""" +import json +import os +import sys +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +mock_database_module = MagicMock() +mock_database_class = MagicMock() +mock_db_instance = MagicMock() +mock_database_class.return_value = mock_db_instance +mock_database_module.Database = mock_database_class + +mock_models_module = MagicMock() +mock_gtfsfeed = MagicMock() +mock_gtfsfeed.stable_id = MagicMock() +mock_gtfsfeed.latest_dataset = MagicMock() +mock_gtfsdataset = MagicMock() +mock_gtfsdataset.validation_reports = MagicMock() +mock_validationreport = MagicMock() +mock_validationreport.notices = MagicMock() +mock_validationreport.features = MagicMock() +mock_models_module.Gtfsfeed = mock_gtfsfeed +mock_models_module.Gtfsdataset = mock_gtfsdataset +mock_models_module.Validationreport = mock_validationreport + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + + +class MockJoinedLoad: + def joinedload(self, *_args, **_kwargs): + return self + + +class TestGetValidationResultsTool: + def setup_method(self): + self.mock_session = MagicMock() + self.mock_query = MagicMock() + self.mock_filtered_query = MagicMock() + self.mock_options_query = MagicMock() + self.mock_session.query.return_value = self.mock_query + self.mock_query.filter.return_value = self.mock_filtered_query + self.mock_filtered_query.options.return_value = self.mock_options_query + self.mock_options_query.first.return_value = None + mock_db_instance.start_db_session.return_value.__enter__ = MagicMock(return_value=self.mock_session) + mock_db_instance.start_db_session.return_value.__exit__ = MagicMock(return_value=False) + + self.mock_rules_cache = MagicMock() + self.mock_rules_cache.get_dict.return_value = None + self.mock_joinedload = MagicMock(return_value=MockJoinedLoad()) + + def _import_tool(self): + for module_name in ["tools.get_validation_results", "tools"]: + sys.modules.pop(module_name, None) + + with patch.dict( + sys.modules, + { + "shared": MagicMock(), + "shared.database": MagicMock(), + "shared.database.database": mock_database_module, + "shared.database_gen": MagicMock(), + "shared.database_gen.sqlacodegen_models": mock_models_module, + }, + ): + from tools import get_validation_results as module + return module + + def _call_tool(self, feed_id="mdb-1210", severity_filter="all"): + module = self._import_tool() + with patch.object(module, "get_rules_cache", return_value=self.mock_rules_cache), patch.object( + module, "joinedload", self.mock_joinedload + ): + return module.get_validation_results_tool(feed_id=feed_id, severity_filter=severity_filter) + + def _make_notice(self, code="missing_stop_name", severity="ERROR", total_notices=3): + notice = MagicMock() + notice.notice_code = code + notice.severity = severity + notice.total_notices = total_notices + return notice + + def _make_report(self, notices=None, features=None, validated_at=None): + report = MagicMock() + report.validated_at = validated_at or datetime(2024, 2, 12, 18, 1, 0) + report.validator_version = "7.1.0" + report.json_report = "https://example.com/report.json" + report.html_report = "https://example.com/report.html" + report.total_error = 2 + report.total_warning = 1 + report.total_info = 0 + report.unique_error_count = 1 + report.unique_warning_count = 1 + report.unique_info_count = 0 + report.notices = notices or [] + report.features = features or [] + return report + + def _make_feed(self, dataset=None, provider="Test Provider"): + feed = MagicMock() + feed.provider = provider + feed.latest_dataset = dataset + return feed + + def _make_dataset(self, stable_id="mdb-1210-202402121801", validation_reports=None): + dataset = MagicMock() + dataset.stable_id = stable_id + dataset.validation_reports = validation_reports or [] + return dataset + + def test_feed_not_found(self): + self.mock_options_query.first.return_value = None + result = json.loads(self._call_tool()) + assert result == {"error": "Feed 'mdb-1210' not found."} + + def test_no_dataset(self): + self.mock_options_query.first.return_value = self._make_feed(dataset=None) + result = json.loads(self._call_tool()) + assert result == {"error": "Feed 'mdb-1210' has no dataset yet."} + + def test_no_validation_report(self): + dataset = self._make_dataset(validation_reports=[]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + result = json.loads(self._call_tool()) + assert result["feed_id"] == "mdb-1210" + assert result["provider"] == "Test Provider" + assert result["dataset_id"] == dataset.stable_id + assert result["error"] == "No validation report found for this feed." + + def test_returns_valid_json(self): + feature = MagicMock() + feature.name = "fares-v2" + notice = self._make_notice() + report = self._make_report(notices=[notice], features=[feature]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset, provider="Agency") + + result = json.loads(self._call_tool()) + assert result["feed_id"] == "mdb-1210" + assert result["provider"] == "Agency" + assert result["dataset_id"] == dataset.stable_id + assert result["validator_version"] == "7.1.0" + assert result["report_urls"]["json"] == "https://example.com/report.json" + assert result["gtfs_features"] == ["fares-v2"] + assert result["summary"]["total_errors"] == 2 + assert len(result["notices"]) == 1 + + def test_severity_filter_errors(self): + error_notice = self._make_notice(code="error_code", severity="ERROR") + warning_notice = self._make_notice(code="warning_code", severity="WARNING") + report = self._make_report(notices=[warning_notice, error_notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + result = json.loads(self._call_tool(severity_filter="errors")) + assert [notice["code"] for notice in result["notices"]] == ["error_code"] + + def test_severity_filter_warnings(self): + error_notice = self._make_notice(code="error_code", severity="ERROR") + warning_notice = self._make_notice(code="warning_code", severity="WARNING") + report = self._make_report(notices=[error_notice, warning_notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + result = json.loads(self._call_tool(severity_filter="warnings")) + assert [notice["code"] for notice in result["notices"]] == ["warning_code"] + + def test_rule_doc_enrichment(self): + self.mock_rules_cache.get_dict.return_value = { + "description": "Missing stop name", + "short_summary": "Stop name missing", + "affected_files": ["stops.txt"], + "rule_url": "https://example.com/rule", + } + notice = self._make_notice(code="missing_stop_name") + report = self._make_report(notices=[notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + result = json.loads(self._call_tool()) + assert result["notices"][0]["rule_doc"]["description"] == "Missing stop name" + + def test_affected_file_fetched(self): + self.mock_rules_cache.get_dict.return_value = { + "description": "Missing stop name", + "short_summary": "Stop name missing", + "affected_files": ["stops.txt"], + "rule_url": "https://example.com/rule", + } + notice = self._make_notice(code="missing_stop_name") + report = self._make_report(notices=[notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + module = self._import_tool() + response = MagicMock() + response.status_code = 200 + response.text = "stop_id,stop_name\n1,Main St\n2,Second St\n" + response.raise_for_status = MagicMock() + client = MagicMock() + client.get.return_value = response + client.__enter__.return_value = client + client.__exit__.return_value = False + + with patch.dict("os.environ", {"DATASETS_BUCKET_URL": "https://example.com"}), patch.object( + module, "get_rules_cache", return_value=self.mock_rules_cache + ), patch.object(module, "joinedload", self.mock_joinedload), patch.object(module.httpx, "Client", return_value=client): + result = json.loads(module.get_validation_results_tool(feed_id="mdb-1210")) + + affected_file = result["notices"][0]["affected_file"] + assert affected_file["filename"] == "stops.txt" + assert affected_file["columns"] == ["stop_id", "stop_name"] + assert affected_file["total_rows"] == 2 + client.get.assert_called_once_with( + "https://example.com/mdb-1210/mdb-1210-202402121801/extracted/stops.txt" + ) + + def test_affected_file_404(self): + self.mock_rules_cache.get_dict.return_value = { + "description": "Missing stop name", + "short_summary": "Stop name missing", + "affected_files": ["stops.txt"], + "rule_url": "https://example.com/rule", + } + notice = self._make_notice(code="missing_stop_name") + report = self._make_report(notices=[notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + module = self._import_tool() + response = MagicMock() + response.status_code = 404 + response.raise_for_status = MagicMock() + client = MagicMock() + client.get.return_value = response + client.__enter__.return_value = client + client.__exit__.return_value = False + + with patch.dict("os.environ", {"DATASETS_BUCKET_URL": "https://example.com"}), patch.object( + module, "get_rules_cache", return_value=self.mock_rules_cache + ), patch.object(module, "joinedload", self.mock_joinedload), patch.object(module.httpx, "Client", return_value=client): + result = json.loads(module.get_validation_results_tool(feed_id="mdb-1210")) + + assert result["notices"][0]["affected_file"] is None + + def test_file_sample_deduplication(self): + def get_rule(_code): + return { + "description": "Shared file", + "short_summary": "Shared file", + "affected_files": ["stops.txt"], + "rule_url": "https://example.com/rule", + } + + self.mock_rules_cache.get_dict.side_effect = get_rule + first_notice = self._make_notice(code="notice_one", total_notices=5) + second_notice = self._make_notice(code="notice_two", total_notices=1) + report = self._make_report(notices=[first_notice, second_notice]) + dataset = self._make_dataset(validation_reports=[report]) + self.mock_options_query.first.return_value = self._make_feed(dataset=dataset) + + module = self._import_tool() + response = MagicMock() + response.status_code = 200 + response.text = "stop_id,stop_name\n1,Main St\n" + response.raise_for_status = MagicMock() + client = MagicMock() + client.get.return_value = response + client.__enter__.return_value = client + client.__exit__.return_value = False + + with patch.dict("os.environ", {"DATASETS_BUCKET_URL": "https://example.com"}), patch.object( + module, "get_rules_cache", return_value=self.mock_rules_cache + ), patch.object(module, "joinedload", self.mock_joinedload), patch.object(module.httpx, "Client", return_value=client): + result = json.loads(module.get_validation_results_tool(feed_id="mdb-1210")) + + assert len(result["notices"]) == 2 + client.get.assert_called_once() diff --git a/mcp/tests/test_query_gtfs.py b/mcp/tests/test_query_gtfs.py new file mode 100644 index 000000000..4a84c055a --- /dev/null +++ b/mcp/tests/test_query_gtfs.py @@ -0,0 +1,236 @@ +""" +Unit tests for the query_gtfs MCP tool. +Tests use mocking to avoid requiring a live database or network. +""" +import json +import os +import sys +from unittest.mock import MagicMock, patch + +import duckdb + +mock_database_module = MagicMock() +mock_database_class = MagicMock() +mock_db_instance = MagicMock() +mock_database_class.return_value = mock_db_instance +mock_database_module.Database = mock_database_class + +mock_models_module = MagicMock() +mock_gtfsfeed = MagicMock() +mock_gtfsfeed.stable_id = MagicMock() +mock_models_module.Gtfsfeed = mock_gtfsfeed + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + + +def _make_con(tables: dict[str, list[dict[str, str]]]) -> duckdb.DuckDBPyConnection: + """Build an in-memory DuckDB with tables from row dicts. + + Args: + tables: mapping of table_name -> list of row dicts + e.g. {"stops": [{"stop_id": "1", "stop_name": "Main St"}]} + """ + con = duckdb.connect() + for table_name, rows in tables.items(): + if not rows: + continue + columns = list(rows[0].keys()) + col_defs = ", ".join(f'"{c}" VARCHAR' for c in columns) + con.execute(f'CREATE TABLE "{table_name}" ({col_defs})') + placeholders = ", ".join(["?"] * len(columns)) + for row in rows: + con.execute( + f'INSERT INTO "{table_name}" VALUES ({placeholders})', + [row[c] for c in columns], + ) + return con + + +class TestQueryGtfsTool: + def setup_method(self): + self.mock_session = MagicMock() + self.mock_query = MagicMock() + self.mock_filtered_query = MagicMock() + self.mock_session.query.return_value = self.mock_query + self.mock_query.filter.return_value = self.mock_filtered_query + self.mock_filtered_query.first.return_value = None + mock_db_instance.start_db_session.return_value.__enter__ = MagicMock(return_value=self.mock_session) + mock_db_instance.start_db_session.return_value.__exit__ = MagicMock(return_value=False) + + def _import_tool(self): + for module_name in ["tools.query_gtfs", "tools", "gtfs_cache"]: + sys.modules.pop(module_name, None) + + with patch.dict( + sys.modules, + { + "shared": MagicMock(), + "shared.database": MagicMock(), + "shared.database.database": mock_database_module, + "shared.database_gen": MagicMock(), + "shared.database_gen.sqlacodegen_models": mock_models_module, + }, + ): + from tools import query_gtfs as module + return module + + def _make_feed(self, dataset=None): + feed = MagicMock() + feed.latest_dataset = dataset + return feed + + def _make_dataset(self, stable_id="mdb-1210-202402121801"): + dataset = MagicMock() + dataset.stable_id = stable_id + return dataset + + def _call_tool(self, module, query="SCHEMA", feed_id="mdb-1210", bucket_url="https://example.com", con=None, files=None): + original = os.environ.get("DATASETS_BUCKET_URL") + if bucket_url: + os.environ["DATASETS_BUCKET_URL"] = bucket_url + elif "DATASETS_BUCKET_URL" in os.environ: + del os.environ["DATASETS_BUCKET_URL"] + try: + if con is not None: + with patch.object(module, "_load_duckdb", return_value=con): + return module.query_gtfs_tool(feed_id=feed_id, query=query, files=files) + return module.query_gtfs_tool(feed_id=feed_id, query=query, files=files) + finally: + if original is not None: + os.environ["DATASETS_BUCKET_URL"] = original + elif "DATASETS_BUCKET_URL" in os.environ: + del os.environ["DATASETS_BUCKET_URL"] + + def test_feed_not_found(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = None + result = json.loads(self._call_tool(module)) + assert result == {"error": "Feed 'mdb-1210' not found."} + + def test_no_dataset(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=None) + result = json.loads(self._call_tool(module)) + assert result == {"error": "Feed 'mdb-1210' has no dataset yet."} + + def test_schema_mode(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({ + "stops": [{"stop_id": "1", "stop_name": "Main St"}, {"stop_id": "2", "stop_name": "Second St"}], + "routes": [{"route_id": "10", "route_short_name": "10"}], + }) + result = json.loads(self._call_tool(module, query="SCHEMA", con=con)) + assert result["feed_id"] == "mdb-1210" + assert result["dataset_id"] == "mdb-1210-202402121801" + assert result["tables"]["stops"]["columns"] == ["stop_id", "stop_name"] + assert result["tables"]["stops"]["row_count"] == 2 + assert result["tables"]["routes"]["row_count"] == 1 + assert result["available_files"] == ["routes.txt", "stops.txt"] + + def test_schema_mode_case_insensitive(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"stops": [{"stop_id": "1", "stop_name": "Main St"}]}) + result = json.loads(self._call_tool(module, query="schema", con=con)) + assert result["tables"]["stops"]["columns"] == ["stop_id", "stop_name"] + + def test_select_query_returns_results(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"stops": [{"stop_id": "1", "stop_name": "Main St"}, {"stop_id": "2", "stop_name": "Second St"}]}) + result = json.loads( + self._call_tool( + module, + query="SELECT stop_id, stop_name FROM stops ORDER BY stop_id", + con=con, + ) + ) + assert result["columns"] == ["stop_id", "stop_name"] + assert result["rows"] == [["1", "Main St"], ["2", "Second St"]] + assert result["row_count"] == 2 + assert isinstance(result["execution_time_ms"], int) + assert result["execution_time_ms"] >= 0 + assert "LIMIT 1000" in result["sql"] + + def test_non_select_rejected(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({}) + result = json.loads(self._call_tool(module, query="DELETE FROM stops", con=con)) + assert "SELECT" in result["error"] + + def test_schema_only_shows_loaded_tables(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"stops": [{"stop_id": "1", "stop_name": "Main St"}]}) + result = json.loads(self._call_tool(module, query="SCHEMA", con=con)) + assert list(result["tables"].keys()) == ["stops"] + assert result["available_files"] == ["stops.txt"] + + def test_cache_hit(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = duckdb.connect() + con.execute("CREATE TABLE stops (stop_id VARCHAR)") + con.execute("INSERT INTO stops VALUES ('1')") + + with patch.dict("os.environ", {"DATASETS_BUCKET_URL": "https://example.com"}), patch.object( + module, "_load_duckdb", return_value=con + ) as mock_loader: + first = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops")) + second = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops")) + + assert first["rows"] == [["1"]] + assert second["rows"] == [["1"]] + assert mock_loader.call_count == 1 + + def test_datasets_bucket_url_missing(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + result = json.loads(self._call_tool(module, bucket_url="")) + assert result == {"error": "DATASETS_BUCKET_URL is not configured."} + + def test_files_parameter_limits_loaded_tables(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({ + "stops": [{"stop_id": "1", "stop_name": "Main St"}], + "routes": [{"route_id": "10", "route_short_name": "10"}], + }) + result = json.loads(self._call_tool(module, query="SCHEMA", con=con, files=["stops", "routes"])) + assert "stops" in result["tables"] + assert "routes" in result["tables"] + assert "trips" not in result["tables"] + + def test_files_parameter_accepts_txt_suffix(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"agency": [{"agency_id": "A1", "agency_name": "Transit Co"}]}) + result = json.loads(self._call_tool(module, query="SCHEMA", con=con, files=["agency.txt"])) + assert "agency" in result["tables"] + + def test_query_infers_tables_when_files_not_provided(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"routes": [{"route_id": "10", "route_short_name": "10"}]}) + with patch.object(module, "_load_duckdb", return_value=con) as spy: + result = json.loads( + self._call_tool( + module, + query="SELECT * FROM routes", + con=None, + files=None, + ) + ) + loaded_files = spy.call_args[0][3] + assert loaded_files == ["routes.txt"] + assert result["columns"] == ["route_id", "route_short_name"] + + def test_files_parameter_invalid_names_return_error(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + result = json.loads(self._call_tool(module, query="SCHEMA", files=["stops", "bogus"])) + assert "error" in result + assert "bogus" in result["error"] + assert "Valid files are" in result["error"] diff --git a/mcp/tests/test_search_feeds.py b/mcp/tests/test_search_feeds.py new file mode 100644 index 000000000..e9f3ba6dc --- /dev/null +++ b/mcp/tests/test_search_feeds.py @@ -0,0 +1,246 @@ +""" +Unit tests for the search_feeds MCP tool. +Tests use mocking to avoid requiring a live database. +""" +import json +import sys +import os +from unittest.mock import MagicMock, patch, PropertyMock +from datetime import datetime + +import pytest + +# We need to mock shared modules before importing the tool since the +# shared/ symlinks may not exist in the test environment. +# Create mock modules for all shared dependencies. +mock_database_module = MagicMock() +mock_database_class = MagicMock() +mock_db_instance = MagicMock() +mock_database_class.return_value = mock_db_instance +mock_database_module.Database = mock_database_class + +mock_unaccent_module = MagicMock() +mock_unaccent = MagicMock(side_effect=lambda x: x) +mock_unaccent_module.unaccent = mock_unaccent + +mock_models_module = MagicMock() + +# Patch sys.modules before importing anything from the tools package +sys.modules['shared'] = MagicMock() +sys.modules['shared.database'] = MagicMock() +sys.modules['shared.database.database'] = mock_database_module +sys.modules['shared.database.sql_functions'] = MagicMock() +sys.modules['shared.database.sql_functions.unaccent'] = mock_unaccent_module +sys.modules['shared.database_gen'] = MagicMock() +sys.modules['shared.database_gen.sqlacodegen_models'] = mock_models_module + +# Add mcp/src to path so we can import tools +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + + +def _make_mock_col(name): + """Create a MagicMock that properly exposes .name as an attribute.""" + col = MagicMock() + col.name = name + return col + + +_UNSET = object() + + +def make_feed_row( + feed_stable_id="mdb-1", + provider="Test Provider", + feed_name="Test Feed", + data_type="gtfs", + status="active", + official=True, + locations=_UNSET, + latest_dataset_id="mdb-1-202401010000", + latest_dataset_hosted_url="https://example.com/feed.zip", + latest_dataset_downloaded_at=None, + latest_dataset_service_date_range_start=None, + latest_dataset_service_date_range_end=None, + latest_total_error=0, + latest_total_warning=2, + latest_total_info=1, + latest_dataset_features=_UNSET, + rank=0.75, +): + """Create a mock database row.""" + row = MagicMock() + row._mapping = { + "feed_stable_id": feed_stable_id, + "provider": provider, + "feed_name": feed_name, + "data_type": data_type, + "status": status, + "official": official, + "locations": locations if locations is not _UNSET else [{"country_code": "CA", "country": "Canada", "subdivision_name": "Quebec", "municipality": "Montreal"}], + "latest_dataset_id": latest_dataset_id, + "latest_dataset_hosted_url": latest_dataset_hosted_url, + "latest_dataset_downloaded_at": latest_dataset_downloaded_at, + "latest_dataset_service_date_range_start": latest_dataset_service_date_range_start, + "latest_dataset_service_date_range_end": latest_dataset_service_date_range_end, + "latest_total_error": latest_total_error, + "latest_total_warning": latest_total_warning, + "latest_total_info": latest_total_info, + "latest_dataset_features": latest_dataset_features if latest_dataset_features is not _UNSET else ["Shapes", "Headsigns"], + "rank": rank, + } + return row + + +class TestSearchFeedsTool: + """Tests for the search_feeds_tool function.""" + + def setup_method(self): + """Set up mocks for each test.""" + mock_models_module.t_feedsearch = MagicMock() + mock_models_module.t_feedsearch.columns = [ + _make_mock_col("feed_stable_id"), + _make_mock_col("provider"), + _make_mock_col("feed_name"), + _make_mock_col("data_type"), + _make_mock_col("status"), + _make_mock_col("official"), + _make_mock_col("locations"), + _make_mock_col("document"), # excluded from feed_search_columns + ] + + # Set up context manager for db session + self.mock_session = MagicMock() + mock_db_instance.start_db_session.return_value.__enter__ = MagicMock(return_value=self.mock_session) + mock_db_instance.start_db_session.return_value.__exit__ = MagicMock(return_value=False) + + def _call_tool(self, search_query="Montreal", **kwargs): + """Import and call the tool fresh (needed due to module-level setup). + + SQLAlchemy is mocked so that select()/func/or_ don't validate column + types — the session.execute() is already mocked, so we only need the + query-builder calls to return chainable MagicMocks. + """ + if 'tools.search_feeds' in sys.modules: + del sys.modules['tools.search_feeds'] + if 'tools' in sys.modules: + del sys.modules['tools'] + + mock_sa = MagicMock() + with patch.dict(sys.modules, {'sqlalchemy': mock_sa}): + from tools.search_feeds import search_feeds_tool + return search_feeds_tool(search_query=search_query, **kwargs) + + def test_returns_valid_json(self): + """Tool always returns valid JSON.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (0,) + result = self._call_tool("Montreal") + parsed = json.loads(result) + assert "query" in parsed + assert "total_matches" in parsed + assert "results" in parsed + + def test_query_echoed_in_output(self): + """The search_query is reflected in the output.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (0,) + result = json.loads(self._call_tool("Montreal")) + assert result["query"] == "Montreal" + + def test_empty_results(self): + """Returns empty results list when no feeds match.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (0,) + result = json.loads(self._call_tool("xyznonexistent")) + assert result["total_matches"] == 0 + assert result["results"] == [] + + def test_result_schema(self): + """Each result has the expected fields.""" + row = make_feed_row() + self.mock_session.execute.return_value.fetchall.return_value = [row] + self.mock_session.execute.return_value.fetchone.return_value = (1,) + result = json.loads(self._call_tool("Montreal")) + assert len(result["results"]) == 1 + feed = result["results"][0] + assert "feed_id" in feed + assert "provider" in feed + assert "feed_name" in feed + assert "data_type" in feed + assert "status" in feed + assert "is_official" in feed + assert "locations" in feed + assert "latest_dataset" in feed + assert "validation_summary" in feed + assert "features" in feed + assert "search_rank" in feed + + def test_result_values_correct(self): + """Result values are correctly mapped from DB row.""" + row = make_feed_row( + feed_stable_id="mdb-956", + provider="STM", + data_type="gtfs", + status="active", + official=True, + rank=0.95, + ) + self.mock_session.execute.return_value.fetchall.return_value = [row] + self.mock_session.execute.return_value.fetchone.return_value = (1,) + result = json.loads(self._call_tool("STM")) + feed = result["results"][0] + assert feed["feed_id"] == "mdb-956" + assert feed["provider"] == "STM" + assert feed["data_type"] == "gtfs" + assert feed["status"] == "active" + assert feed["is_official"] is True + assert abs(feed["search_rank"] - 0.95) < 0.01 + + def test_validation_summary_in_result(self): + """Validation summary fields are present.""" + row = make_feed_row(latest_total_error=3, latest_total_warning=7, latest_total_info=2) + self.mock_session.execute.return_value.fetchall.return_value = [row] + self.mock_session.execute.return_value.fetchone.return_value = (1,) + result = json.loads(self._call_tool("test")) + feed = result["results"][0] + assert feed["validation_summary"]["total_error"] == 3 + assert feed["validation_summary"]["total_warning"] == 7 + assert feed["validation_summary"]["total_info"] == 2 + + def test_locations_in_result(self): + """Locations array is included in result for disambiguation.""" + locations = [{"country_code": "CA", "country": "Canada", "subdivision_name": "Quebec", "municipality": "Montreal"}] + row = make_feed_row(locations=locations) + self.mock_session.execute.return_value.fetchall.return_value = [row] + self.mock_session.execute.return_value.fetchone.return_value = (1,) + result = json.loads(self._call_tool("Montreal")) + feed = result["results"][0] + assert feed["locations"] == locations + + def test_empty_query_returns_all_feeds(self): + """Empty search_query returns all feeds without text filter.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (5,) + result = json.loads(self._call_tool("")) + assert result["query"] == "" + assert result["total_matches"] == 5 + + def test_none_fields_handled_gracefully(self): + """None values in DB row are handled without exceptions.""" + row = make_feed_row( + feed_name=None, + latest_dataset_id=None, + latest_dataset_hosted_url=None, + latest_dataset_downloaded_at=None, + latest_total_error=None, + latest_dataset_features=None, + ) + row._mapping["locations"] = None + self.mock_session.execute.return_value.fetchall.return_value = [row] + self.mock_session.execute.return_value.fetchone.return_value = (1,) + # Should not raise + result = json.loads(self._call_tool("test")) + feed = result["results"][0] + assert feed["feed_name"] is None + assert feed["locations"] == [] + assert feed["features"] == [] diff --git a/scripts/docker-build-push.sh b/scripts/docker-build-push.sh index a3a826a4a..84d27735c 100755 --- a/scripts/docker-build-push.sh +++ b/scripts/docker-build-push.sh @@ -41,6 +41,8 @@ display_usage() { echo " -repo_name The GCP Artifactory repository name." echo " -service The cloud run service name." echo " -version The container's image version to push." + echo " -dockerfile Path to the Dockerfile (default: api/Dockerfile)." + echo " -context Docker build context directory (default: api/)." exit 1 } @@ -48,6 +50,8 @@ PROJECT_ID="" SERVICE="" REGION="" VERSION="" +DOCKERFILE="" +CONTEXT="" while [[ $# -gt 0 ]]; do key="$1" @@ -78,6 +82,16 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -dockerfile) + DOCKERFILE="$2" + shift # past argument + shift # past value + ;; + -context) + CONTEXT="$2" + shift # past argument + shift # past value + ;; -h|--help) display_usage ;; @@ -96,7 +110,14 @@ fi # relative path SCRIPT_PATH="$(dirname -- "${BASH_SOURCE[0]}")" +# Default build context is api/, default Dockerfile is inside the context +BUILD_CONTEXT="${CONTEXT:-$SCRIPT_PATH/../api}" +DOCKERFILE_ARG="" +if [[ -n "${DOCKERFILE}" ]]; then + DOCKERFILE_ARG="-f ${DOCKERFILE}" +fi + DOCKER_TAG=${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPO_NAME}/${SERVICE}:${VERSION} # Build your Docker image -docker buildx build --push --platform linux/amd64 --no-cache -t $DOCKER_TAG $SCRIPT_PATH/../api +docker buildx build --push --platform linux/amd64 --no-cache $DOCKERFILE_ARG -t $DOCKER_TAG $BUILD_CONTEXT diff --git a/scripts/function-python-setup.sh b/scripts/function-python-setup.sh index 6e94ab2b2..8257d3ef6 100755 --- a/scripts/function-python-setup.sh +++ b/scripts/function-python-setup.sh @@ -43,7 +43,8 @@ display_usage() { echo "Options:" echo " -h|--help Display help content." echo " --function_name Name of the function to be setup." - echo " --all Setup all functions." + echo " --all Setup all functions (includes MCP)." + echo " --mcp Setup MCP server shared folders only." echo " --clean Clean shared folders." exit 1 } @@ -51,6 +52,7 @@ display_usage() { FX_NAME_PARAM='' ALL='false' CLEAN='false' +MCP='false' while [[ $# -gt 0 ]]; do key="$1" @@ -72,6 +74,10 @@ while [[ $# -gt 0 ]]; do CLEAN="true" shift ;; + --mcp) + MCP="true" + shift + ;; *) # unknown option shift # past argument ;; @@ -174,6 +180,31 @@ clean_shared_folders() { rmdir "$FUNCTIONS_PATH/$function_name/src/test_shared" > /dev/null 2>&1 } +setup_mcp() { + MCP_PATH="$ROOT_PATH/mcp" + MCP_SOURCE_PATH="$MCP_PATH/src" + MCP_CONFIG_FILE="$MCP_PATH/mcp_config.json" + + if [ ! -f "$MCP_CONFIG_FILE" ]; then + echo "INFO: No mcp_config.json found at $MCP_CONFIG_FILE, skipping MCP setup." + return + fi + + echo "Setting up MCP server shared folders" + include_api_folders=$(jq -r '.include_api_folders[]' "$MCP_CONFIG_FILE" 2>/dev/null) + + dst_folder="$MCP_SOURCE_PATH/shared" + rm -rf "$dst_folder" + mkdir -p "$dst_folder" + create_symbolic_links "$API_PATH" "$include_api_folders" "$dst_folder" +} + +clean_mcp() { + echo "INFO: Cleaning MCP shared folders" + rm -f "$ROOT_PATH/mcp/src/shared/"* + rmdir "$ROOT_PATH/mcp/src/shared" 2>/dev/null || true +} + if [ "$ALL" = "true" ]; then # get all the functions in the functions-python folder that contain a function_config.json file for function in $(find "$FUNCTIONS_PATH" -maxdepth 2 -name "function_config.json"); do @@ -184,6 +215,18 @@ if [ "$ALL" = "true" ]; then setup_function $function_name fi done + # Also process MCP when --all is used + if [ "$CLEAN" = "true" ]; then + clean_mcp + else + setup_mcp + fi +elif [ "$MCP" = "true" ]; then + if [ "$CLEAN" = "true" ]; then + clean_mcp + else + setup_mcp + fi else if [ -z "$FX_NAME_PARAM" ]; then printf "\nERROR: function name not provided" From 86eb0e62210cf1b5d1dbbb24039cc7246447b7ab Mon Sep 17 00:00:00 2001 From: cka-y Date: Thu, 14 May 2026 14:35:29 -0400 Subject: [PATCH 2/5] fix: req.txt context --- mcp/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp/Dockerfile b/mcp/Dockerfile index 351181286..df2821537 100644 --- a/mcp/Dockerfile +++ b/mcp/Dockerfile @@ -3,7 +3,7 @@ FROM python:3.11-slim WORKDIR /app # Copy requirements first for layer caching -COPY requirements.txt . +COPY mcp/requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Copy shared code from api (symlinks won't work in Docker, copy directly) From 37ffe758aacbe888da20ab74611651fad4148f70 Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 19 May 2026 13:27:34 -0400 Subject: [PATCH 3/5] fix: duckdb extension --- mcp/Dockerfile | 3 +++ mcp/src/tools/query_gtfs.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mcp/Dockerfile b/mcp/Dockerfile index df2821537..e1228c12c 100644 --- a/mcp/Dockerfile +++ b/mcp/Dockerfile @@ -15,6 +15,9 @@ COPY mcp/src/ /app/ # Set Python path so imports work ENV PYTHONPATH=/app +# Pre-install DuckDB extensions so they're available offline at runtime +RUN python -c "import duckdb; con = duckdb.connect(); con.install_extension('httpfs'); con.close()" + EXPOSE 8080 CMD ["python", "main.py"] diff --git a/mcp/src/tools/query_gtfs.py b/mcp/src/tools/query_gtfs.py index fbffb0040..7b2b45f70 100644 --- a/mcp/src/tools/query_gtfs.py +++ b/mcp/src/tools/query_gtfs.py @@ -52,7 +52,6 @@ def _file_name_for_table(table_name: str) -> str: def _load_duckdb(feed_id: str, dataset_id: str, datasets_bucket_url: str, files: list[str]) -> duckdb.DuckDBPyConnection: """Load GTFS files directly from GCS into an in-memory DuckDB via httpfs.""" con = duckdb.connect() - con.install_extension("httpfs") con.load_extension("httpfs") base_url = f"{datasets_bucket_url}/{feed_id}/{dataset_id}/extracted" From aabb5adaa07285fba5ecfe0bea994d73586e311b Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 19 May 2026 14:01:50 -0400 Subject: [PATCH 4/5] fix: added failure propagation from duckdb --- mcp/README.md | 87 +++++++--------------------------- mcp/src/gtfs_cache.py | 15 +++--- mcp/src/tools/query_gtfs.py | 72 +++++++++++++++++++--------- mcp/src/tools/search_feeds.py | 8 ---- mcp/tests/test_query_gtfs.py | 39 ++++++++------- mcp/tests/test_search_feeds.py | 2 - 6 files changed, 98 insertions(+), 125 deletions(-) diff --git a/mcp/README.md b/mcp/README.md index db170eff4..f43c778a3 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -12,7 +12,6 @@ Searches the Mobility Database using PostgreSQL full-text search against the `Fe |---|---|---|---| | `search_query` | string | — | Free-text search (e.g. `"Montreal"`, `"STM"`, `"Japan"`) | | `data_type` | string | `gtfs` | `gtfs`, `gtfs_rt`, or `gbfs` | -| `status` | string | none | `active`, `inactive`, or `deprecated` | | `is_official` | boolean | none | Filter to official feeds only | | `limit` | integer | `30` | Max results | @@ -38,8 +37,9 @@ Loads a feed's extracted GTFS files into an in-memory DuckDB database and execut |---|---|---|---| | `feed_id` | string | — | Mobility Database feed ID (e.g. `mdb-1210`) | | `query` | string | — | `SCHEMA` to list tables/columns, or any SQL `SELECT` statement | +| `files` | list[string] | — | GTFS files to load (e.g. `["stops", "routes"]`). **Required for SELECT queries.** Omit for SCHEMA queries to discover all available tables. | -> Use `SCHEMA` first to discover available tables, then write SQL to answer your question. Only `SELECT` queries are allowed. +> Use `SCHEMA` first to discover available tables, then write SQL to answer your question. Only `SELECT` queries are allowed. If some GTFS files are unavailable (e.g. not extracted), they are reported in `failed_files` rather than silently skipped. ## Running locally @@ -87,6 +87,21 @@ cd mcp python -m pytest tests/ -v ``` +### Testing with the MCP Inspector + +Start the server locally, then in a separate terminal launch the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) to interactively test tools: + +```bash +# 1. Start the server (SSE mode, default) +cd mcp/src +PYTHONPATH=. python main.py + +# 2. In another terminal, launch the inspector +DANGEROUSLY_OMIT_AUTH=true npx @modelcontextprotocol/inspector +``` + +Open the URL printed by the inspector, enter `http://localhost:8080/sse` as the server URL, and you can list and call tools directly and test them like any API. This is a great way to iterate on tool outputs and debug without needing to connect a full MCP client. + ## Running with Docker The Dockerfile build context is the **repo root** (needed to copy shared code): @@ -135,73 +150,7 @@ Restart Claude Desktop. You should see the `search_feeds`, `get_validation_resul > *"What validation errors does mdb-1210 have?"* > *"Show me the errors in the STM feed"* -## Questions that feel like magic - -### Single-tool deep dives - -Once Claude can query raw GTFS tables directly, you can ask things that would be painful to answer by hand: - -> *"Which route in the STM network serves the most unique stops?"* -> *"Are there any trips that depart after midnight? List them with their headsigns."* -> *"What percentage of stops in the Tokyo feed have wheelchair boarding info?"* -> *"Which agency in the feed operates the most distinct routes?"* -> *"Find all stops within 500m of each other that belong to different routes — potential transfer point opportunities."* -> *"How many trips run on weekdays vs weekends for each route?"* -> *"What's the average dwell time between consecutive stops on the busiest route?"* -> *"Show me routes that have no shapes defined — they'll render as straight lines on a map."* -> *"Which stops appear in stop_times but are missing from stops.txt?"* -> *"What's the earliest first departure and latest last departure across all routes?"* - -These are the kinds of questions that become conversational once an AI can inspect `trips.txt`, `stop_times.txt`, `routes.txt`, `calendar.txt`, `shapes.txt`, and the rest through SQL. - -### Cross-tool investigations - -The real power emerges when the tools chain together. Claude will figure out the sequence — you just ask the question. - -> *"Find all active GTFS feeds in California, then tell me which ones have validation errors and what those errors are."* -> → `search_feeds` (California) → `get_validation_results` for each → summary report - -> *"I'm traveling in Tokyo next month — which feed covers the Tokyo Metro, does it have any errors that could affect trip planning, and what routes serve Shinjuku station?"* -> → `search_feeds` (Tokyo Metro) → `get_validation_results` → `query_gtfs` (stops near Shinjuku) - -> *"Compare the data quality of the top 5 transit agencies in Canada — who has the cleanest feed?"* -> → `search_feeds` (Canada, limit=5) → `get_validation_results` for each → ranked comparison by error count - -> *"The STM feed has a `stop_times_with_only_arrival_time` warning — can you show me which trips are affected and what the schedule looks like for those stops?"* -> → `get_validation_results` (mdb-956, warnings) → `query_gtfs` (SELECT from stop_times WHERE arrival_time IS NOT NULL AND departure_time IS NULL) - -> *"Are there any official feeds in Europe with zero validation errors? If so, what GTFS features do they support?"* -> → `search_feeds` (Europe, is_official=true) → `get_validation_results` for each → filter zero errors → list features - -> *"Find the feed for the Paris Métro, check if it has wheelchair accessibility data, and count what fraction of stops actually have it filled in."* -> → `search_feeds` (Paris Métro) → `get_validation_results` (check features list) → `query_gtfs` (SELECT COUNT(*) by wheelchair_boarding value) - -> *"Which city in Japan has the most comprehensive GTFS feed — most routes, most stops, fewest errors?"* -> → `search_feeds` (Japan) → `get_validation_results` for each → `query_gtfs` SCHEMA on top candidates → compare route/stop counts - -## Architecture - -``` -Claude Desktop (or any MCP client) - │ MCP Protocol (stdio locally, SSE when deployed) - ▼ - MCP Server (Python) - ├──► SQLAlchemy → PostgreSQL (feed metadata, search, validation reports) - └──► DuckDB (in-memory) ← GTFS CSVs fetched from GCS public URLs -``` - -The server connects **directly to the database** — it does not call the public Feed API. It reuses the shared database models and query logic from `api/src/shared/` (linked via symlinks in `src/shared/`). ## Deployment -Terraform infrastructure is in `infra/mcp/`. The module creates a Cloud Run service (`mcp-server-{env}`) and is wired into the root `infra/main.tf`. Deploy by building and pushing the Docker image to Artifact Registry, then running terraform: - -```bash -# Build and push (from repo root) -docker build -f mcp/Dockerfile -t {region}-docker.pkg.dev/{project}/{repo}/mcp-server:{version} . -docker push {region}-docker.pkg.dev/{project}/{repo}/mcp-server:{version} - -# Apply terraform -cd infra -terraform apply -var="mcp_image_version={version}" -``` +Terraform infrastructure is in `infra/mcp/`. The module creates a Cloud Run service (`mcp-server-{env}`) and is wired into the root `infra/main.tf`. Deploy by building and pushing the Docker image to Artifact Registry, then running terraform deployment. diff --git a/mcp/src/gtfs_cache.py b/mcp/src/gtfs_cache.py index d5f10fbc4..379eb91be 100644 --- a/mcp/src/gtfs_cache.py +++ b/mcp/src/gtfs_cache.py @@ -9,10 +9,11 @@ class _CacheEntry: - __slots__ = ("connection", "loaded_at", "lock") + __slots__ = ("connection", "failed_files", "loaded_at", "lock") def __init__(self): self.connection: Optional[duckdb.DuckDBPyConnection] = None + self.failed_files: list[str] = [] self.loaded_at: float = 0.0 self.lock = threading.Lock() @@ -43,12 +44,12 @@ def get_or_load( self, feed_id: str, dataset_id: str, - loader_fn: Callable[[], duckdb.DuckDBPyConnection], - ) -> duckdb.DuckDBPyConnection: + loader_fn: Callable[[], tuple[duckdb.DuckDBPyConnection, list[str]]], + ) -> tuple[duckdb.DuckDBPyConnection, list[str]]: key = (feed_id, dataset_id) entry = self._cache.get(key) if entry and self._is_fresh(entry): - return entry.connection + return entry.connection, entry.failed_files if entry is None: with self._cache_lock: @@ -60,9 +61,9 @@ def get_or_load( old_connection: Optional[duckdb.DuckDBPyConnection] = None with entry.lock: if self._is_fresh(entry): - return entry.connection + return entry.connection, entry.failed_files old_connection = entry.connection - entry.connection = loader_fn() + entry.connection, entry.failed_files = loader_fn() entry.loaded_at = time.monotonic() connection = entry.connection @@ -72,7 +73,7 @@ def get_or_load( except Exception: pass - return connection + return connection, entry.failed_files _gtfs_cache: Optional[GtfsCache] = None diff --git a/mcp/src/tools/query_gtfs.py b/mcp/src/tools/query_gtfs.py index 7b2b45f70..901f5e0e9 100644 --- a/mcp/src/tools/query_gtfs.py +++ b/mcp/src/tools/query_gtfs.py @@ -49,11 +49,17 @@ def _file_name_for_table(table_name: str) -> str: STANDARD_TABLE_NAMES = frozenset(_table_name_for_file(f) for f in STANDARD_GTFS_FILES) -def _load_duckdb(feed_id: str, dataset_id: str, datasets_bucket_url: str, files: list[str]) -> duckdb.DuckDBPyConnection: - """Load GTFS files directly from GCS into an in-memory DuckDB via httpfs.""" +def _load_duckdb( + feed_id: str, dataset_id: str, datasets_bucket_url: str, files: list[str] +) -> tuple[duckdb.DuckDBPyConnection, list[str]]: + """Load GTFS files directly from GCS into an in-memory DuckDB via httpfs. + + Returns (connection, list_of_failed_filenames). + """ con = duckdb.connect() con.load_extension("httpfs") + failed_files: list[str] = [] base_url = f"{datasets_bucket_url}/{feed_id}/{dataset_id}/extracted" for filename in files: table_name = _table_name_for_file(filename) @@ -65,8 +71,9 @@ def _load_duckdb(feed_id: str, dataset_id: str, datasets_bucket_url: str, files: ) except Exception as exc: logger.warning("Failed to load %s: %s", filename, exc) + failed_files.append(filename) - return con + return con, failed_files def _resolve_dataset(feed_id: str) -> tuple[str | None, str | None]: @@ -115,12 +122,18 @@ def _validate_files( Returns (filenames, error). When *error* is not None the caller should return it directly — *filenames* will be empty. """ - user_provided = bool(files) - if not files and query and query.strip().upper() != "SCHEMA": - files = _extract_tables_from_query(query) + is_schema = (query or "").strip().upper() == "SCHEMA" if not files: - return list(STANDARD_GTFS_FILES), None + if is_schema: + return list(STANDARD_GTFS_FILES), None + valid_list = ", ".join(sorted(STANDARD_TABLE_NAMES)) + return [], json.dumps({ + "error": "The 'files' parameter is required for SELECT queries. " + "Specify which GTFS files to load (e.g. [\"stops\", \"routes\"]). " + "Use query=\"SCHEMA\" first to discover available tables. " + f"Valid files are: {valid_list}", + }) normalized = [] invalid = [] @@ -134,17 +147,20 @@ def _validate_files( continue normalized.append(_file_name_for_table(table_name)) - if invalid and user_provided: + if invalid: valid_list = ", ".join(sorted(STANDARD_TABLE_NAMES)) return [], json.dumps({ "error": f"Invalid GTFS file(s): {', '.join(invalid)}. " f"Valid files are: {valid_list}", }) - return (normalized if normalized else list(STANDARD_GTFS_FILES)), None + if not normalized: + return list(STANDARD_GTFS_FILES), None + + return normalized, None -def query_gtfs_tool(feed_id: str, query: str, files: Optional[list[str]] = None) -> str: +def query_gtfs_tool(feed_id: str, query: str, files: list[str] = None) -> str: """ Load a GTFS feed into an in-memory DuckDB database and execute SQL queries. @@ -156,10 +172,10 @@ def query_gtfs_tool(feed_id: str, query: str, files: Optional[list[str]] = None) Args: feed_id: Mobility Database feed ID (e.g. "mdb-1210") query: Either "SCHEMA" to list tables/columns, or a SQL SELECT statement - files: Optional list of GTFS files to load (e.g. ["stops", "routes", "trips"]). + files: List of GTFS files to load (e.g. ["stops", "routes", "trips"]). Accepts table names or filenames (e.g. "stops" or "stops.txt"). - If omitted, all standard GTFS files are loaded. - Tip: only load the tables you need for much faster responses. + Required for SELECT queries — only load the tables you need for faster responses. + For SCHEMA queries, omit to discover all available tables. Returns: JSON string with schema info or query results @@ -181,7 +197,7 @@ def query_gtfs_tool(feed_id: str, query: str, files: Optional[list[str]] = None) started_at = time.perf_counter() try: - con = cache.get_or_load( + con, failed_files = cache.get_or_load( feed_id, f"{dataset_id}:{cache_key_suffix}", lambda: _load_duckdb(feed_id, dataset_id, datasets_bucket_url, target_files), @@ -197,22 +213,34 @@ def query_gtfs_tool(feed_id: str, query: str, files: Optional[list[str]] = None) default=str, ) - if (query or "").strip().upper() == "SCHEMA": - cursor = con.cursor() - try: - tables, available_files = _get_schema(cursor) - finally: - cursor.close() + if failed_files and len(failed_files) == len(target_files): return json.dumps( { "feed_id": feed_id, "dataset_id": dataset_id, - "tables": tables, - "available_files": available_files, + "error": "No GTFS files could be loaded for this feed. " + "The dataset may not have extracted files available.", + "failed_files": failed_files, }, default=str, ) + if (query or "").strip().upper() == "SCHEMA": + cursor = con.cursor() + try: + tables, available_files = _get_schema(cursor) + finally: + cursor.close() + result = { + "feed_id": feed_id, + "dataset_id": dataset_id, + "tables": tables, + "available_files": available_files, + } + if failed_files: + result["failed_files"] = failed_files + return json.dumps(result, default=str) + normalized_query = (query or "").strip().rstrip(";") if not normalized_query.upper().startswith("SELECT"): return json.dumps( diff --git a/mcp/src/tools/search_feeds.py b/mcp/src/tools/search_feeds.py index 3101c86e2..dd493b67b 100644 --- a/mcp/src/tools/search_feeds.py +++ b/mcp/src/tools/search_feeds.py @@ -16,7 +16,6 @@ def get_parsed_search_tsquery(search_query: str): def search_feeds_tool( search_query: str, data_type: Optional[str] = "gtfs", - status: Optional[str] = None, is_official: Optional[bool] = None, limit: Optional[int] = 30, ) -> str: @@ -29,7 +28,6 @@ def search_feeds_tool( Args: search_query: Free-text search (e.g., "Montreal", "Japan", "STM") data_type: One of: gtfs, gtfs_rt, gbfs. Default: gtfs - status: Feed status filter: active, deprecated, inactive. Default: no filter is_official: Filter for official feeds only limit: Max results to return. Default: 30 @@ -50,10 +48,6 @@ def search_feeds_tool( data_types = [dt.strip().lower() for dt in data_type.split(",")] query = query.where(t_feedsearch.c.data_type.in_(data_types)) - if status: - statuses = [s.strip().lower() for s in status.split(",")] - query = query.where(t_feedsearch.c.status.in_(statuses)) - if is_official is not None: if is_official: query = query.where(t_feedsearch.c.official.is_(True)) @@ -75,8 +69,6 @@ def search_feeds_tool( count_query = count_query.filter(t_feedsearch.c.operational_status == "published") if data_type: count_query = count_query.where(t_feedsearch.c.data_type.in_(data_types)) - if status: - count_query = count_query.where(t_feedsearch.c.status.in_(statuses)) if is_official is not None: if is_official: count_query = count_query.where(t_feedsearch.c.official.is_(True)) diff --git a/mcp/tests/test_query_gtfs.py b/mcp/tests/test_query_gtfs.py index 4a84c055a..6059c0369 100644 --- a/mcp/tests/test_query_gtfs.py +++ b/mcp/tests/test_query_gtfs.py @@ -92,7 +92,7 @@ def _call_tool(self, module, query="SCHEMA", feed_id="mdb-1210", bucket_url="htt del os.environ["DATASETS_BUCKET_URL"] try: if con is not None: - with patch.object(module, "_load_duckdb", return_value=con): + with patch.object(module, "_load_duckdb", return_value=(con, [])): return module.query_gtfs_tool(feed_id=feed_id, query=query, files=files) return module.query_gtfs_tool(feed_id=feed_id, query=query, files=files) finally: @@ -144,6 +144,7 @@ def test_select_query_returns_results(self): module, query="SELECT stop_id, stop_name FROM stops ORDER BY stop_id", con=con, + files=["stops"], ) ) assert result["columns"] == ["stop_id", "stop_name"] @@ -176,10 +177,10 @@ def test_cache_hit(self): con.execute("INSERT INTO stops VALUES ('1')") with patch.dict("os.environ", {"DATASETS_BUCKET_URL": "https://example.com"}), patch.object( - module, "_load_duckdb", return_value=con + module, "_load_duckdb", return_value=(con, []) ) as mock_loader: - first = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops")) - second = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops")) + first = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops", files=["stops"])) + second = json.loads(module.query_gtfs_tool(feed_id="mdb-1210", query="SELECT * FROM stops", files=["stops"])) assert first["rows"] == [["1"]] assert second["rows"] == [["1"]] @@ -210,22 +211,26 @@ def test_files_parameter_accepts_txt_suffix(self): result = json.loads(self._call_tool(module, query="SCHEMA", con=con, files=["agency.txt"])) assert "agency" in result["tables"] - def test_query_infers_tables_when_files_not_provided(self): + def test_select_without_files_returns_error(self): module = self._import_tool() self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) - con = _make_con({"routes": [{"route_id": "10", "route_short_name": "10"}]}) - with patch.object(module, "_load_duckdb", return_value=con) as spy: - result = json.loads( - self._call_tool( - module, - query="SELECT * FROM routes", - con=None, - files=None, - ) + result = json.loads( + self._call_tool( + module, + query="SELECT * FROM routes", + con=None, + files=None, ) - loaded_files = spy.call_args[0][3] - assert loaded_files == ["routes.txt"] - assert result["columns"] == ["route_id", "route_short_name"] + ) + assert "error" in result + assert "'files' parameter is required" in result["error"] + + def test_schema_without_files_loads_all(self): + module = self._import_tool() + self.mock_filtered_query.first.return_value = self._make_feed(dataset=self._make_dataset()) + con = _make_con({"routes": [{"route_id": "10", "route_short_name": "10"}]}) + result = json.loads(self._call_tool(module, query="SCHEMA", con=con, files=None)) + assert "tables" in result def test_files_parameter_invalid_names_return_error(self): module = self._import_tool() diff --git a/mcp/tests/test_search_feeds.py b/mcp/tests/test_search_feeds.py index e9f3ba6dc..8dcfd9898 100644 --- a/mcp/tests/test_search_feeds.py +++ b/mcp/tests/test_search_feeds.py @@ -181,7 +181,6 @@ def test_result_values_correct(self): feed_stable_id="mdb-956", provider="STM", data_type="gtfs", - status="active", official=True, rank=0.95, ) @@ -192,7 +191,6 @@ def test_result_values_correct(self): assert feed["feed_id"] == "mdb-956" assert feed["provider"] == "STM" assert feed["data_type"] == "gtfs" - assert feed["status"] == "active" assert feed["is_official"] is True assert abs(feed["search_rank"] - 0.95) < 0.01 From bf255f2449605cf78489c2e9b701b16b2c9933c5 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 27 May 2026 08:37:12 -0400 Subject: [PATCH 5/5] [wip] mcp poc --- .github/workflows/db-deployer.yml | 27 ++++++++++++++++++++++++-- .github/workflows/db-prod.yml | 2 ++ .github/workflows/db-qa.yml | 2 ++ infra/mcp/main.tf | 4 ++-- infra/postgresql/main.tf | 19 ++++++++++++++++++ infra/postgresql/vars.tf | 10 ++++++++++ infra/postgresql/vars.tfvars.rename_me | 2 ++ mcp/README.md | 3 ++- mcp/src/tools/search_feeds.py | 14 ++++++++++--- mcp/tests/test_search_feeds.py | 19 ++++++++++++++++++ 10 files changed, 94 insertions(+), 8 deletions(-) diff --git a/.github/workflows/db-deployer.yml b/.github/workflows/db-deployer.yml index 07ecf82ed..5c5b454fb 100644 --- a/.github/workflows/db-deployer.yml +++ b/.github/workflows/db-deployer.yml @@ -15,6 +15,12 @@ on: POSTGRE_USER_PASSWORD: description: PostgreSQL User Password required: true + POSTGRE_READONLY_USER_NAME: + description: PostgreSQL Read-Only User Name (used by MCP server) + required: true + POSTGRE_READONLY_USER_PASSWORD: + description: PostgreSQL Read-Only User Password + required: true POSTGRE_SQL_INSTANCE_NAME: description: PostgreSQL Instance Name required: true @@ -93,13 +99,15 @@ jobs: echo "POSTGRE_SQL_DB_NAME=${{ inputs.POSTGRE_SQL_DB_NAME }}" >> $GITHUB_ENV echo "POSTGRE_USER_NAME=${{ secrets.POSTGRE_USER_NAME }}" >> $GITHUB_ENV echo "POSTGRE_USER_PASSWORD=${{ secrets.POSTGRE_USER_PASSWORD }}" >> $GITHUB_ENV + echo "POSTGRE_READONLY_USER_NAME=${{ secrets.POSTGRE_READONLY_USER_NAME }}" >> $GITHUB_ENV + echo "POSTGRE_READONLY_USER_PASSWORD=${{ secrets.POSTGRE_READONLY_USER_PASSWORD }}" >> $GITHUB_ENV echo "POSTGRE_INSTANCE_TIER=${{ inputs.POSTGRE_INSTANCE_TIER }}" >> $GITHUB_ENV echo "MAX_CONNECTIONS=${{ inputs.MAX_CONNECTIONS }}" >> $GITHUB_ENV - name: Populate Variables run: | scripts/replace-variables.sh -in_file infra/backend.conf.rename_me -out_file infra/postgresql/backend.conf -variables BUCKET_NAME,OBJECT_PREFIX - scripts/replace-variables.sh -in_file infra/postgresql/vars.tfvars.rename_me -out_file infra/postgresql/vars.tfvars -variables ENVIRONMENT,PROJECT_ID,REGION,DEPLOYER_SERVICE_ACCOUNT,POSTGRE_SQL_INSTANCE_NAME,POSTGRE_SQL_DB_NAME,POSTGRE_USER_NAME,POSTGRE_USER_PASSWORD,POSTGRE_INSTANCE_TIER,MAX_CONNECTIONS + scripts/replace-variables.sh -in_file infra/postgresql/vars.tfvars.rename_me -out_file infra/postgresql/vars.tfvars -variables ENVIRONMENT,PROJECT_ID,REGION,DEPLOYER_SERVICE_ACCOUNT,POSTGRE_SQL_INSTANCE_NAME,POSTGRE_SQL_DB_NAME,POSTGRE_USER_NAME,POSTGRE_USER_PASSWORD,POSTGRE_READONLY_USER_NAME,POSTGRE_READONLY_USER_PASSWORD,POSTGRE_INSTANCE_TIER,MAX_CONNECTIONS - name: Install Terraform uses: hashicorp/setup-terraform@v3 @@ -149,6 +157,8 @@ jobs: env: POSTGRE_USER_NAME: ${{ secrets.POSTGRE_USER_NAME }} POSTGRE_USER_PASSWORD: ${{ secrets.POSTGRE_USER_PASSWORD }} + POSTGRE_READONLY_USER_NAME: ${{ secrets.POSTGRE_READONLY_USER_NAME }} + POSTGRE_READONLY_USER_PASSWORD: ${{ secrets.POSTGRE_READONLY_USER_PASSWORD }} POSTGRE_SQL_DB_NAME: ${{ inputs.POSTGRE_SQL_DB_NAME }} DB_INSTANCE_HOST: ${{ needs.terraform.outputs.db_instance_host }} steps: @@ -165,7 +175,20 @@ jobs: SECRET_NAME="DEV_FEEDS_DATABASE_URL" SECRET_VALUE="postgresql://${{ env.POSTGRE_USER_NAME }}:${{ env.POSTGRE_USER_PASSWORD }}@${{ env.DB_INSTANCE_HOST }}/${{ env.POSTGRE_SQL_DB_NAME }}DEV" echo $SECRET_VALUE - + + if gcloud secrets describe $SECRET_NAME --project=mobility-feeds-dev; then + echo "Secret $SECRET_NAME already exists, updating..." + echo -n "$SECRET_VALUE" | gcloud secrets versions add $SECRET_NAME --data-file=- --project=mobility-feeds-dev + else + echo "Secret $SECRET_NAME does not exist, creating..." + echo -n "$SECRET_VALUE" | gcloud secrets create $SECRET_NAME --data-file=- --replication-policy="automatic" --project=mobility-feeds-dev + fi + + - name: Create or Update Readonly Secret in DEV + run: | + SECRET_NAME="DEV_FEEDS_DATABASE_URL_READONLY" + SECRET_VALUE="postgresql://${{ env.POSTGRE_READONLY_USER_NAME }}:${{ env.POSTGRE_READONLY_USER_PASSWORD }}@${{ env.DB_INSTANCE_HOST }}/${{ env.POSTGRE_SQL_DB_NAME }}DEV" + if gcloud secrets describe $SECRET_NAME --project=mobility-feeds-dev; then echo "Secret $SECRET_NAME already exists, updating..." echo -n "$SECRET_VALUE" | gcloud secrets versions add $SECRET_NAME --data-file=- --project=mobility-feeds-dev diff --git a/.github/workflows/db-prod.yml b/.github/workflows/db-prod.yml index 1a2d4ef42..874fad653 100644 --- a/.github/workflows/db-prod.yml +++ b/.github/workflows/db-prod.yml @@ -21,6 +21,8 @@ jobs: secrets: POSTGRE_USER_PASSWORD: ${{ secrets.PROD_POSTGRE_USER_PASSWORD }} POSTGRE_USER_NAME: ${{ secrets.PROD_POSTGRE_USER_NAME }} + POSTGRE_READONLY_USER_PASSWORD: ${{ secrets.PROD_POSTGRE_READONLY_USER_PASSWORD }} + POSTGRE_READONLY_USER_NAME: ${{ secrets.PROD_POSTGRE_READONLY_USER_NAME }} POSTGRE_SQL_INSTANCE_NAME: ${{ secrets.DB_INSTANCE_NAME }} GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.PROD_GCP_MOBILITY_FEEDS_SA_KEY }} DEV_GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.DEV_GCP_MOBILITY_FEEDS_SA_KEY }} \ No newline at end of file diff --git a/.github/workflows/db-qa.yml b/.github/workflows/db-qa.yml index 71331a6cd..26b559bad 100644 --- a/.github/workflows/db-qa.yml +++ b/.github/workflows/db-qa.yml @@ -21,6 +21,8 @@ jobs: secrets: POSTGRE_USER_PASSWORD: ${{ secrets.QA_POSTGRE_USER_PASSWORD }} POSTGRE_USER_NAME: ${{ secrets.QA_POSTGRE_USER_NAME }} + POSTGRE_READONLY_USER_PASSWORD: ${{ secrets.QA_POSTGRE_READONLY_USER_PASSWORD }} + POSTGRE_READONLY_USER_NAME: ${{ secrets.QA_POSTGRE_READONLY_USER_NAME }} POSTGRE_SQL_INSTANCE_NAME: ${{ secrets.DB_INSTANCE_NAME }} GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.QA_GCP_MOBILITY_FEEDS_SA_KEY }} DEV_GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.DEV_GCP_MOBILITY_FEEDS_SA_KEY }} \ No newline at end of file diff --git a/infra/mcp/main.tf b/infra/mcp/main.tf index 4acfdf1e8..ccc699c65 100644 --- a/infra/mcp/main.tf +++ b/infra/mcp/main.tf @@ -81,7 +81,7 @@ resource "google_cloud_run_v2_service" "mcp_server" { name = "FEEDS_DATABASE_URL" value_source { secret_key_ref { - secret = "${upper(var.environment)}_FEEDS_DATABASE_URL" + secret = "${upper(var.environment)}_FEEDS_DATABASE_URL_READONLY" version = "latest" } } @@ -118,7 +118,7 @@ resource "google_cloud_run_service_iam_policy" "noauth" { resource "google_secret_manager_secret_iam_member" "feeds_db_url_access" { project = var.project_id - secret_id = "${upper(var.environment)}_FEEDS_DATABASE_URL" + secret_id = "${upper(var.environment)}_FEEDS_DATABASE_URL_READONLY" role = "roles/secretmanager.secretAccessor" member = "serviceAccount:${google_service_account.mcp_service_account.email}" } diff --git a/infra/postgresql/main.tf b/infra/postgresql/main.tf index 2d1677e97..c46e5e185 100644 --- a/infra/postgresql/main.tf +++ b/infra/postgresql/main.tf @@ -95,6 +95,12 @@ resource "google_sql_user" "users" { password = var.postgresql_user_password } +resource "google_sql_user" "readonly_user" { + name = var.postgresql_readonly_user_name + instance = google_sql_database_instance.db.name + password = var.postgresql_readonly_user_password +} + resource "google_secret_manager_secret" "secret_db_url" { project = var.project_id secret_id = "${upper(var.environment)}_FEEDS_DATABASE_URL" @@ -108,6 +114,19 @@ resource "google_secret_manager_secret_version" "secret_version" { secret_data = "postgresql+psycopg2://${var.postgresql_user_name}:${var.postgresql_user_password}@${google_sql_database_instance.db.private_ip_address}/${var.postgresql_database_name}" } +resource "google_secret_manager_secret" "secret_db_url_readonly" { + project = var.project_id + secret_id = "${upper(var.environment)}_FEEDS_DATABASE_URL_READONLY" + replication { + auto {} + } +} + +resource "google_secret_manager_secret_version" "secret_version_readonly" { + secret = google_secret_manager_secret.secret_db_url_readonly.id + secret_data = "postgresql+psycopg2://${var.postgresql_readonly_user_name}:${var.postgresql_readonly_user_password}@${google_sql_database_instance.db.private_ip_address}/${var.postgresql_database_name}" +} + output "instance_address" { description = "The first public IPv4 address of the SQL instance" value = google_sql_database_instance.db.private_ip_address diff --git a/infra/postgresql/vars.tf b/infra/postgresql/vars.tf index 5647c0780..14c50062f 100644 --- a/infra/postgresql/vars.tf +++ b/infra/postgresql/vars.tf @@ -52,4 +52,14 @@ variable "postgresql_db_instance" { variable "max_db_connections" { type = string description = "Maximum number of database connections" +} + +variable "postgresql_readonly_user_name" { + type = string + description = "The name of the read-only PostgreSQL user (used by MCP server)" +} + +variable "postgresql_readonly_user_password" { + type = string + description = "The password for the read-only PostgreSQL user" } \ No newline at end of file diff --git a/infra/postgresql/vars.tfvars.rename_me b/infra/postgresql/vars.tfvars.rename_me index e0a669b92..0762d0496 100644 --- a/infra/postgresql/vars.tfvars.rename_me +++ b/infra/postgresql/vars.tfvars.rename_me @@ -14,5 +14,7 @@ postgresql_instance_name = {{POSTGRE_SQL_INSTANCE_NAME}} postgresql_database_name = {{POSTGRE_SQL_DB_NAME}} postgresql_user_name = {{POSTGRE_USER_NAME}} postgresql_user_password = {{POSTGRE_USER_PASSWORD}} +postgresql_readonly_user_name = {{POSTGRE_READONLY_USER_NAME}} +postgresql_readonly_user_password = {{POSTGRE_READONLY_USER_PASSWORD}} postgresql_db_instance = {{POSTGRE_INSTANCE_TIER}} max_db_connections = {{MAX_CONNECTIONS}} \ No newline at end of file diff --git a/mcp/README.md b/mcp/README.md index f43c778a3..f47d3c8e7 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -13,7 +13,8 @@ Searches the Mobility Database using PostgreSQL full-text search against the `Fe | `search_query` | string | — | Free-text search (e.g. `"Montreal"`, `"STM"`, `"Japan"`) | | `data_type` | string | `gtfs` | `gtfs`, `gtfs_rt`, or `gbfs` | | `is_official` | boolean | none | Filter to official feeds only | -| `limit` | integer | `30` | Max results | +| `limit` | integer | `30` | Max results per page | +| `offset` | integer | `0` | Number of results to skip (for pagination) | **Step 2 — Validation Results (`get_validation_results` tool)** diff --git a/mcp/src/tools/search_feeds.py b/mcp/src/tools/search_feeds.py index dd493b67b..c7f88789d 100644 --- a/mcp/src/tools/search_feeds.py +++ b/mcp/src/tools/search_feeds.py @@ -18,6 +18,7 @@ def search_feeds_tool( data_type: Optional[str] = "gtfs", is_official: Optional[bool] = None, limit: Optional[int] = 30, + offset: Optional[int] = 0, ) -> str: """ Search the Mobility Database for GTFS/GBFS/GTFS-RT feeds. @@ -25,14 +26,18 @@ def search_feeds_tool( Returns rich location and metadata context so the AI can disambiguate between results (e.g., Montreal Quebec vs Montréal-du-Gers France). + Supports pagination via limit and offset. The response includes total_matches so you + know how many more results are available. + Args: search_query: Free-text search (e.g., "Montreal", "Japan", "STM") data_type: One of: gtfs, gtfs_rt, gbfs. Default: gtfs is_official: Filter for official feeds only - limit: Max results to return. Default: 30 + limit: Max results per page. Default: 30 + offset: Number of results to skip for pagination. Default: 0 Returns: - JSON string with query, total_matches, and results array with feed metadata + JSON string with query, total_matches, offset, limit, and results array with feed metadata """ db = Database() with db.start_db_session() as session: @@ -79,7 +84,7 @@ def search_feeds_tool( if search_query and len(search_query.strip()) > 0: count_query = count_query.filter(t_feedsearch.c.document.op("@@")(ts_query)) - rows = session.execute(query.limit(limit)).fetchall() + rows = session.execute(query.offset(offset).limit(limit)).fetchall() total_count_result = session.execute(count_query).fetchone() total_count = total_count_result[0] if total_count_result else 0 @@ -90,6 +95,7 @@ def search_feeds_tool( "feed_id": row_dict.get("feed_stable_id"), "provider": row_dict.get("provider"), "feed_name": row_dict.get("feed_name"), + "producer_url": row_dict.get("producer_url"), "data_type": row_dict.get("data_type"), "status": row_dict.get("status"), "is_official": row_dict.get("official"), @@ -121,6 +127,8 @@ def search_feeds_tool( { "query": search_query, "total_matches": total_count, + "offset": offset, + "limit": limit, "results": results, }, default=str, diff --git a/mcp/tests/test_search_feeds.py b/mcp/tests/test_search_feeds.py index 8dcfd9898..5b17ef2fc 100644 --- a/mcp/tests/test_search_feeds.py +++ b/mcp/tests/test_search_feeds.py @@ -52,6 +52,7 @@ def make_feed_row( feed_stable_id="mdb-1", provider="Test Provider", feed_name="Test Feed", + producer_url="https://example.com/gtfs", data_type="gtfs", status="active", official=True, @@ -73,6 +74,7 @@ def make_feed_row( "feed_stable_id": feed_stable_id, "provider": provider, "feed_name": feed_name, + "producer_url": producer_url, "data_type": data_type, "status": status, "official": official, @@ -166,6 +168,7 @@ def test_result_schema(self): assert "feed_id" in feed assert "provider" in feed assert "feed_name" in feed + assert "producer_url" in feed assert "data_type" in feed assert "status" in feed assert "is_official" in feed @@ -215,6 +218,22 @@ def test_locations_in_result(self): feed = result["results"][0] assert feed["locations"] == locations + def test_pagination_offset_and_limit_in_response(self): + """Response includes offset and limit for pagination.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (50,) + result = json.loads(self._call_tool("Montreal", offset=30, limit=10)) + assert result["offset"] == 30 + assert result["limit"] == 10 + assert result["total_matches"] == 50 + + def test_default_offset_is_zero(self): + """Default offset is 0 when not specified.""" + self.mock_session.execute.return_value.fetchall.return_value = [] + self.mock_session.execute.return_value.fetchone.return_value = (0,) + result = json.loads(self._call_tool("Montreal")) + assert result["offset"] == 0 + def test_empty_query_returns_all_feeds(self): """Empty search_query returns all feeds without text filter.""" self.mock_session.execute.return_value.fetchall.return_value = []