diff --git a/CHANGELOG.md b/CHANGELOG.md index b1c1fd19..d7f7620c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,42 @@ All notable changes to this project will be documented in this file. +## 1.4.0 + +### Added — Support for multiple models per version + +A dataset version can now own many trainings, and a training can produce many +models (e.g. a NAS sweep). New object types expose this: + +**SDK (`roboflow/core/training.py`, `roboflow/core/version.py`):** +- `Version.trainings()` — list the version's training runs as `Training` objects. +- `Version.models()` — every trained model for the version (the union across its + trainings), as `TrainedModel` objects. This is now the canonical way to get a + version's models. +- `Version.create_training(speed=, model_type=, checkpoint=, epochs=)` — launch a + run without blocking, returning a `Training`. +- `Training` — `.models`, `.refresh()`, `.cancel()`, `.stop()`, plus + `.training_id` / `.status` / `.model_type`. +- `TrainedModel` — `.predict()`, `.predict_video()`, `.download()`, plus + `.model_id` / `.model_type` / `.metrics`. A `TrainedModel` does everything the + old `version.model` could; you just reach it through `version.models()`. + +**Adapters (`roboflow/adapters/rfapi.py`):** v2 trainings endpoints — +`list_trainings_for_version`, `get_training`, `create_training_v2`, +`cancel_training_v2`, `stop_training_v2`, `get_model_weights_url`. + +### Changed + +- Keypoint detection inference now reports its prediction type correctly + (previously mislabeled as classification), fixing rendering/plotting of + keypoint predictions. + +### Deprecated + +- `version.model` (the singular attribute) is deprecated and emits a + `DeprecationWarning`. It cannot represent a version with multiple models; + use `version.models()` instead. + ## 1.3.10 ### Added diff --git a/README.md b/README.md index 5e034ff1..63b78ead 100644 --- a/README.md +++ b/README.md @@ -145,8 +145,8 @@ version = project.version("VERSION_NUMBER") # upload model weights - yolov10 version.deploy(model_type="yolov10", model_path=f”{HOME}/runs/detect/train/”, filename="weights.pt") -# run inference -model = version.model +# run inference (a version may own several trained models; models() returns all of them) +model = version.models()[0] img_url = "https://media.roboflow.com/quickstart/aerial_drone.jpeg" diff --git a/docs/core/training.md b/docs/core/training.md new file mode 100644 index 00000000..56f51fc4 --- /dev/null +++ b/docs/core/training.md @@ -0,0 +1 @@ +:::roboflow.core.training diff --git a/docs/index.md b/docs/index.md index 37cab58d..7fd48edf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -94,8 +94,8 @@ version = project.version("VERSION_NUMBER") # upload model weights - yolov10 version.deploy(model_type="yolov10", model_path=f”{HOME}/runs/detect/train/”, filename="weights.pt") -# run inference -model = version.model +# run inference (a version may own several trained models; models() returns all of them) +model = version.models()[0] img_url = "https://media.roboflow.com/quickstart/aerial_drone.jpeg" diff --git a/mkdocs.yml b/mkdocs.yml index e6b4a5df..11543b97 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -34,6 +34,7 @@ nav: - Projects: core/project.md - Workspaces: core/workspace.md - Versions: core/version.md + - Trainings: core/training.md - Models: - Object Detection: models/object-detection.md - Classification: models/classification.md diff --git a/pyproject.toml b/pyproject.toml index 4de8c9ef..f84031c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,14 @@ banned-module-level-imports = [ python_version = "3.10" exclude = ["^build/"] +# numpy's bundled stubs use PEP 695 `type` statements, which mypy rejects when +# checking against python_version 3.10. Skip following them so the type checker +# doesn't choke on numpy's own stub syntax. +[[tool.mypy.overrides]] +module = ["numpy", "numpy.*"] +follow_imports = "skip" +follow_imports_for_stubs = true + [[tool.mypy.overrides]] module = [ "_datetime.*", diff --git a/roboflow/__init__.py b/roboflow/__init__.py index 560ffbbc..c248fa03 100644 --- a/roboflow/__init__.py +++ b/roboflow/__init__.py @@ -21,7 +21,7 @@ CLIPModel = None # type: ignore[assignment,misc] GazeModel = None # type: ignore[assignment,misc] -__version__ = "1.3.10" +__version__ = "1.4.0" def check_key(api_key, model, notebook, num_retries=0): @@ -168,7 +168,9 @@ def load_model(model_url): project = operate_workspace.project(project) version = project.version(version) - model = version.model + # version.model is deprecated; read the underlying legacy model directly so + # load_model keeps its single-model return contract without emitting the warning. + model = getattr(version, "_model", None) return model diff --git a/roboflow/adapters/rfapi.py b/roboflow/adapters/rfapi.py index eef2e2a7..b3fe2da1 100644 --- a/roboflow/adapters/rfapi.py +++ b/roboflow/adapters/rfapi.py @@ -155,6 +155,139 @@ def get_training_results(api_key: str, workspace_url: str, project_url: str, ver return response.json() +# --------------------------------------------------------------------------- +# DNA v2 trainings surface (MMPV-aware). Mirrors the MCP's rf_api.py 1:1: a +# version owns many trainings, each owning one or more models (a NAS run owns +# many). trainingId rides in the query/body, never the path, because legacy ids +# contain slashes. The legacy-vs-MMPV branch lives entirely on the backend. +# --------------------------------------------------------------------------- + + +def list_trainings_for_version(api_key: str, workspace_url: str, project_url: str, version: str): + """List a version's trainings (DNA ``trainings.list``). + + GET /{ws}/{proj}/{version}/v2/trainings. MMPV versions return every run; + SMPV versions return a single entry synthesized from ``version.train``. + Returns the raw ``trainings`` array — each entry carries + ``{trainingId, status, modelType, modelGroup, modelIds, start}``. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/v2/trainings?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + data = response.json() + return data.get("trainings", []) or [] + + +def get_training(api_key: str, workspace_url: str, project_url: str, version: str, training_id=None): + """A single run's results bundle (DNA ``trainings.get``). + + GET /{ws}/{proj}/{version}/v2/trainings/get[?trainingId=]. Omitting + ``training_id`` targets the version's sole run; a version that owns several + runs responds 409 (list them and pass a specific id). Returns + ``{trainingId, status, modelType, modelGroup, modelCount, models: [...]}``, + each model carrying an inference-style ``modelId`` (``/``). + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/v2/trainings/get?api_key={api_key}" + if training_id: + url += f"&trainingId={quote(str(training_id), safe='')}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json() + + +def create_training_v2( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + *, + speed: Optional[str] = None, + checkpoint: Optional[str] = None, + model_type: Optional[str] = None, + epochs: Optional[int] = None, +): + """Create a training on a version (DNA ``trainings.create``). + + POST /{ws}/{proj}/{version}/v2/trainings. A version may own many trainings, + so repeated/concurrent runs are allowed; the backend rejects a second run on + a legacy (SMPV) version. Returns ``{trainingId, status, jobId}``. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/v2/trainings?api_key={api_key}" + data: Dict[str, Union[str, int]] = {} + if speed is not None: + data["speed"] = speed + if checkpoint is not None: + data["checkpoint"] = checkpoint + if model_type is not None: + data["modelType"] = model_type + if epochs is not None: + data["epochs"] = epochs + response = requests.post(url, json=data) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"status": "training_started"} + + +def cancel_training_v2( + api_key: str, + workspace_url: str, + project_url: str, + version: str, + training_id=None, + continue_if_no_refund: bool = False, +): + """Cancel an in-flight run (DNA ``trainings.cancel``). + + POST /{ws}/{proj}/{version}/v2/trainings/cancel. ``training_id`` selects a + specific run; omit it to target the version's sole run. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/v2/trainings/cancel?api_key={api_key}" + body: Dict[str, Union[str, bool]] = {} + if training_id: + body["trainingId"] = training_id + if continue_if_no_refund: + body["continueIfNoRefund"] = True + response = requests.post(url, json=body) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def stop_training_v2(api_key: str, workspace_url: str, project_url: str, version: str, training_id=None): + """Request an early stop on an in-flight run (DNA ``trainings.stop``). + + POST /{ws}/{proj}/{version}/v2/trainings/stop. ``training_id`` selects a + specific run; omit it to target the version's sole run. + """ + url = f"{API_URL}/{workspace_url}/{project_url}/{version}/v2/trainings/stop?api_key={api_key}" + body: Dict[str, str] = {} + if training_id: + body["trainingId"] = training_id + response = requests.post(url, json=body) + if not response.ok: + raise RoboflowError(response.text) + return response.json() if response.content else {"success": True} + + +def get_model_weights_url(api_key: str, workspace_url: str, project_url: str, model_id: str, model_format: str = "pt"): + """Resolve a signed PyTorch weights URL for a single trained model. + + GET /{ws}/{proj}/{model_id}/ptFile, where ``model_id`` is the addressable + segment of an inference-style id — a model slug (MMPV) or a version number + (SMPV). Returns the signed ``weightsUrl``. + """ + if model_format != "pt": + raise RoboflowError(f"Unsupported weights format '{model_format}'. Only 'pt' is supported.") + encoded = quote(str(model_id), safe="") + url = f"{API_URL}/{workspace_url}/{project_url}/{encoded}/ptFile?api_key={api_key}" + response = requests.get(url) + if not response.ok: + raise RoboflowError(response.text) + return response.json()["weightsUrl"] + + def list_project_models( api_key: str, workspace_url: str, diff --git a/roboflow/cli/handlers/model.py b/roboflow/cli/handlers/model.py index 033d52dd..2f304fc7 100644 --- a/roboflow/cli/handlers/model.py +++ b/roboflow/cli/handlers/model.py @@ -196,15 +196,15 @@ def _list_models(args): # noqa: ANN001 models = [] for v in versions: - if v.model: + # version.model is deprecated; read the underlying legacy model directly. + v_model = getattr(v, "_model", None) + if v_model: models.append( { "version": v.version, "id": v.id, "model": getattr(v, "model_format", ""), - "map": getattr(v, "model", {}).get("map", "") - if isinstance(getattr(v, "model", None), dict) - else "", + "map": v_model.get("map", "") if isinstance(v_model, dict) else "", } ) diff --git a/roboflow/cli/handlers/video.py b/roboflow/cli/handlers/video.py index 0045189b..1fdc9905 100644 --- a/roboflow/cli/handlers/video.py +++ b/roboflow/cli/handlers/video.py @@ -56,7 +56,15 @@ def _video_infer(args) -> None: # noqa: ANN001 rf = roboflow.Roboflow(api_key) project = rf.workspace().project(args.project) version = project.version(args.version_number) - model = version.model + model = getattr(version, "_model", None) + if model is None: + output_error( + args, + f"No model found for project '{args.project}' version {args.version_number}.", + hint="Train or deploy a model for this version before running video inference.", + exit_code=3, + ) + return job_id, _signed_url, _expire_time = model.predict_video( args.video_file, diff --git a/roboflow/config.py b/roboflow/config.py index 0d569e95..800ef6ac 100644 --- a/roboflow/config.py +++ b/roboflow/config.py @@ -44,6 +44,7 @@ def get_conditional_configuration_variable(key, default): CLASSIFICATION_MODEL = os.getenv("CLASSIFICATION_MODEL", "ClassificationModel") INSTANCE_SEGMENTATION_MODEL = "InstanceSegmentationModel" +KEYPOINT_DETECTION_MODEL = "KeypointDetectionModel" OBJECT_DETECTION_MODEL = os.getenv("OBJECT_DETECTION_MODEL", "ObjectDetectionModel") SEMANTIC_SEGMENTATION_MODEL = "SemanticSegmentationModel" PREDICTION_OBJECT = os.getenv("PREDICTION_OBJECT", "Prediction") diff --git a/roboflow/core/training.py b/roboflow/core/training.py new file mode 100644 index 00000000..66880ddc --- /dev/null +++ b/roboflow/core/training.py @@ -0,0 +1,273 @@ +"""DNA-style Training / TrainedModel objects for MMPV (multiple-models-per-version). + +A Version owns many Trainings; each Training owns one or more Models (a NAS run +owns many). These objects couple to the v2 trainings adapter (``rfapi``), which +mirrors the platform's DNA operations 1:1 — the legacy-vs-MMPV branch lives on +the backend, never here. +""" + +from __future__ import annotations + +import json +import os +from typing import List + +import requests + +from roboflow.adapters import rfapi +from roboflow.config import ( + CLASSIFICATION_MODEL, + INSTANCE_SEGMENTATION_MODEL, + KEYPOINT_DETECTION_MODEL, + OBJECT_DETECTION_MODEL, + OBJECT_DETECTION_URL, + SEMANTIC_SEGMENTATION_MODEL, + SEMANTIC_SEGMENTATION_URL, + TASK_CLS, + TASK_OBB, + TASK_POSE, + TASK_SEG, + TASK_SEM, +) +from roboflow.models.inference import InferenceModel +from roboflow.util.model_processor import task_of_model_type + + +def _serverless_base_url_for_task(task: str) -> str: + if task == TASK_SEM: + return SEMANTIC_SEGMENTATION_URL + return OBJECT_DETECTION_URL + + +def _prediction_type_for_task(task: str) -> str: + if task == TASK_CLS: + return CLASSIFICATION_MODEL + elif task == TASK_SEG: + return INSTANCE_SEGMENTATION_MODEL + elif task == TASK_SEM: + return SEMANTIC_SEGMENTATION_MODEL + elif task == TASK_POSE: + return KEYPOINT_DETECTION_MODEL + elif task == TASK_OBB: + return OBJECT_DETECTION_MODEL + else: + return OBJECT_DETECTION_MODEL + + +class TrainedModel: + """A single trained model produced by a Training. + + Wraps an inference-style model id of either form — ``/`` + (SMPV) or ``/`` (MMPV). Inference goes to the + serverless host by that id (which the server resolves to the model and its + task); weights download keys off the id's addressable segment. + """ + + def __init__(self, api_key, workspace, project, model_id, model_type=None, metrics=None): + self.__api_key = api_key + self.workspace = workspace + self.project = project + self.model_id = model_id + self.model_type = model_type + self.metrics = metrics + # The second segment addresses the model on /ptFile: a model slug for + # MMPV, a version number for SMPV. + self._weights_id = model_id.split("/", 1)[1] if "/" in str(model_id) else model_id + self._video_model_cache = None + + def predict(self, image_path, hosted=False, confidence=40, overlap=30, format="json", **kwargs): + """Run hosted inference on an image by this model's id. + + The id is passed straight to serverless, which resolves the model and + its task. Returns a ``PredictionGroup``. Set ``hosted=True`` when + ``image_path`` is a public URL. + """ + task = task_of_model_type(self.model_type or "") + prediction_type = _prediction_type_for_task(task) + base_url = _serverless_base_url_for_task(task).rstrip("/") + model = InferenceModel(self.__api_key, "BASE_MODEL") + model.api_url = f"{base_url}/{str(self.model_id).strip('/')}" + model.colors = {} + + params = {"confidence": confidence, "overlap": overlap, "format": format} + params.update(kwargs) + return model.predict(image_path, prediction_type=prediction_type, **params) + + def _video_model(self): + """Build (and cache) the legacy inference model used for video inference. + + Video upload and result polling still flow through the legacy + ``/videoinfer`` endpoints, which the task-specific models implement. + Caching keeps ``predict_video`` and the poll methods on one underlying + object, so a job started here can be polled without re-passing its id. + """ + if self._video_model_cache is not None: + return self._video_model_cache + + from roboflow.models.classification import ClassificationModel + from roboflow.models.instance_segmentation import InstanceSegmentationModel + from roboflow.models.keypoint_detection import KeypointDetectionModel + from roboflow.models.object_detection import ObjectDetectionModel + from roboflow.models.semantic_segmentation import SemanticSegmentationModel + + task = task_of_model_type(self.model_type or "") + legacy_class = { + TASK_CLS: ClassificationModel, + TASK_SEG: InstanceSegmentationModel, + TASK_SEM: SemanticSegmentationModel, + TASK_POSE: KeypointDetectionModel, + }.get(task, ObjectDetectionModel) + + legacy_id = f"{self.workspace}/{self.project}/{self._weights_id}" + self._video_model_cache = legacy_class(self.__api_key, legacy_id) + return self._video_model_cache + + def predict_video(self, video_path, fps=5, additional_models=None, prediction_type="batch-video"): + """Run hosted video inference for this model (DNA-era equivalent of the + legacy ``version.model.predict_video``). + + Delegates to the task-appropriate legacy inference model built from this + model's id, so a ``TrainedModel`` can do everything the old + ``version.model`` could. Returns ``(job_id, signed_url, expires)``; poll + with :meth:`poll_until_video_results` on the same object. + + NOTE: the legacy ``/videoinfer`` payload is keyed by ``/``. + For MMPV models addressed by ``/`` this routes the + slug through as the version segment; verify against staging before relying + on it for slug-addressed models. + """ + return self._video_model().predict_video( + video_path, fps=fps, additional_models=additional_models, prediction_type=prediction_type + ) + + def poll_for_video_results(self, job_id=None) -> dict: + """Check once for this model's video inference results (DNA-era equivalent + of the legacy ``version.model.poll_for_video_results``). + + Returns ``{}`` while the job is still running. Defaults to the job started + by the most recent :meth:`predict_video` call on this object. + """ + return self._video_model().poll_for_video_results(job_id) + + def poll_until_video_results(self, job_id=None) -> dict: + """Block until this model's video inference job completes, returning the + results (DNA-era equivalent of the legacy + ``version.model.poll_until_video_results``). + + Defaults to the job started by the most recent :meth:`predict_video` call + on this object. + """ + return self._video_model().poll_until_video_results(job_id) + + def download(self, format="pt", location="."): + """Download this model's PyTorch weights to ``location/weights.pt``.""" + weights_url = rfapi.get_model_weights_url( + self.__api_key, self.workspace, self.project, self._weights_id, model_format=format + ) + os.makedirs(location, exist_ok=True) + out_path = os.path.join(location, "weights.pt") + response = requests.get(weights_url, stream=True) + response.raise_for_status() + with open(out_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return out_path + + def __str__(self): + return json.dumps( + {"model_id": self.model_id, "model_type": self.model_type, "metrics": self.metrics}, + indent=2, + ) + + +class Training: + """One training run on a dataset version. + + A version may own many trainings; a NAS run produces many models. Couples to + the v2 trainings adapter — ``.models`` resolves the run's produced models via + ``trainings.get``. + """ + + def __init__(self, api_key, workspace, project, version, raw): + self.__api_key = api_key + self.workspace = workspace + self.project = project + self.version = version + self._raw = raw or {} + self.training_id = self._raw.get("trainingId") or self._raw.get("id") + self.status = self._raw.get("status") + self.model_type = self._raw.get("modelType") + self.model_group = self._raw.get("modelGroup") + self.model_ids = self._raw.get("modelIds", []) or [] + self._models_cache = None + + @property + def models(self) -> List["TrainedModel"]: + """The models this run produced (DNA ``trainings.get`` → ``models[]``).""" + if self._models_cache is not None: + return self._models_cache + + bundle = rfapi.get_training( + self.__api_key, self.workspace, self.project, self.version, training_id=self.training_id + ) + bundle_model_type = bundle.get("modelType") or self.model_type + models = [] + for entry in bundle.get("models", []) or []: + model_id = entry.get("modelId") + if not model_id: + continue + models.append( + TrainedModel( + self.__api_key, + self.workspace, + self.project, + model_id, + model_type=entry.get("modelType") or bundle_model_type, + metrics=entry.get("metrics"), + ) + ) + self._models_cache = models + return self._models_cache + + def refresh(self) -> "Training": + """Re-read this run's status/results from the backend in place.""" + bundle = rfapi.get_training( + self.__api_key, self.workspace, self.project, self.version, training_id=self.training_id + ) + self._raw.update(bundle) + self.training_id = bundle.get("trainingId") or bundle.get("id") or self.training_id + self.status = bundle.get("status", self.status) + self.model_type = bundle.get("modelType", self.model_type) + self.model_group = bundle.get("modelGroup", self.model_group) + self.model_ids = bundle.get("modelIds", self.model_ids) + self._models_cache = None + return self + + def cancel(self, continue_if_no_refund: bool = False): + """Cancel this run immediately (DNA ``trainings.cancel``).""" + return rfapi.cancel_training_v2( + self.__api_key, + self.workspace, + self.project, + self.version, + training_id=self.training_id, + continue_if_no_refund=continue_if_no_refund, + ) + + def stop(self): + """Request a graceful early stop on this run (DNA ``trainings.stop``).""" + return rfapi.stop_training_v2( + self.__api_key, self.workspace, self.project, self.version, training_id=self.training_id + ) + + def __str__(self): + return json.dumps( + { + "training_id": self.training_id, + "status": self.status, + "model_type": self.model_type, + "model_group": self.model_group, + }, + indent=2, + ) diff --git a/roboflow/core/version.py b/roboflow/core/version.py index fd4a3f5f..e52fc538 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -5,6 +5,7 @@ import os import sys import time +import warnings from typing import TYPE_CHECKING, Optional, Union import requests @@ -50,8 +51,6 @@ class Version: Class representing a Roboflow dataset version. """ - model: Optional[InferenceModel] - def __init__( self, version_dict, @@ -94,12 +93,11 @@ def __init__( version_without_workspace = os.path.basename(str(version)) - try: - version_response = rfapi.get_version(self.__api_key, workspace, project, self.version) - version_info = version_response.get("version", {}) - has_model = bool(version_info.get("train", {}).get("model")) - except rfapi.RoboflowError: - has_model = False + # Derive the legacy single-model flag from the payload the caller + # already fetched. Keeping __init__ free of network side effects means + # a transient/mocked request failure can't break basic version + # retrieval; the v2 surface (models()/trainings()) does its own reads. + has_model = bool(version_dict.get("model")) if not has_model: self.model = None @@ -162,6 +160,80 @@ def __init__( self.version = "23" self.id = "joseph-nelson/chess-pieces-new" + @property + def model(self): + """Deprecated. The version's legacy single inference model, or ``None``. + + A version may now own many trained models (MMPV). This single-model + attribute cannot represent that, so it is deprecated in favor of + :meth:`models`, which returns every trained model for the version, and + :meth:`trainings`, which exposes the runs that produced them. + """ + warnings.warn( + "version.model is deprecated and will be removed in a future release; " + "use version.models() (all trained models) or version.trainings() instead.", + DeprecationWarning, + stacklevel=2, + ) + return getattr(self, "_model", None) + + @model.setter + def model(self, value): + self._model = value + + def trainings(self): + """List this version's trainings as Training objects (DNA ``trainings.list``). + + An MMPV version may own many; a legacy (SMPV) version reports its single + run. Returns a list of :class:`~roboflow.core.training.Training`. + """ + from roboflow.core.training import Training + + raw = rfapi.list_trainings_for_version(self.__api_key, self.workspace, self.project, self.version) + return [Training(self.__api_key, self.workspace, self.project, self.version, t) for t in raw] + + def models(self): + """All trained models for this version — the union across its trainings. + + Mirrors the backend's "a version's models are the union across its + trainings" rule. Returns a list of + :class:`~roboflow.core.training.TrainedModel`. + """ + result = [] + for training in self.trainings(): + result.extend(training.models) + return result + + def create_training(self, speed=None, model_type=None, checkpoint=None, epochs=None): + """Create a v2 training run and return a Training object. + + Unlike :meth:`train`, this does not block until completion or return a + legacy task-specific model. It exposes the MMPV-aware training id so + callers can refresh the run, enumerate produced models, and select the + model they want. + """ + from roboflow.core.training import Training + + self.__wait_if_generating() + + if model_type: + train_model_format = get_model_format(model_type) + if train_model_format not in self.exports: + self.export(train_model_format) + + workspace, project, *_ = self.id.rsplit("/") + raw = rfapi.create_training_v2( + api_key=self.__api_key, + workspace_url=workspace, + project_url=project, + version=self.version, + speed=speed if speed else None, + checkpoint=checkpoint if checkpoint else None, + model_type=model_type if model_type else None, + epochs=epochs, + ) + return Training(self.__api_key, workspace, project, self.version, raw) + def __check_if_generating(self): # check Roboflow API to see if this version is still generating versiondict = rfapi.get_version( @@ -451,7 +523,7 @@ def live_plot(epochs, mAP, loss, title=""): time.sleep(5) - if not self.model: + if not getattr(self, "_model", None): if self.type == TYPE_OBJECT_DETECTION: self.model = ObjectDetectionModel( self.__api_key, @@ -485,8 +557,8 @@ def live_plot(epochs, mAP, loss, title=""): raise ValueError(f"Unsupported model type: {self.type}") # return the model object - assert self.model - return self.model + assert self._model + return self._model # @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")]) def deploy(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None: diff --git a/roboflow/core/workspace.py b/roboflow/core/workspace.py index c1dca153..5589f58f 100644 --- a/roboflow/core/workspace.py +++ b/roboflow/core/workspace.py @@ -677,8 +677,11 @@ def active_learning( else: local = None - inference_model = ( - self.project(inference_endpoint[0]).version(version_number=inference_endpoint[1], local=local).model + # version.model is deprecated; read the underlying legacy model directly. + inference_model = getattr( + self.project(inference_endpoint[0]).version(version_number=inference_endpoint[1], local=local), + "_model", + None, ) upload_project = self.project(upload_destination) @@ -718,7 +721,7 @@ def active_learning( print(image2 + " --> similarity too high to --> " + image1) continue # skip this image if too similar or counter hits limit - predictions = inference_model.predict(image).json()["predictions"] # type: ignore[attribute-error] + predictions = inference_model.predict(image).json()["predictions"] # type: ignore[union-attr] # collect all predictions to return to user at end prediction_results.append({"image": image, "predictions": predictions}) diff --git a/roboflow/models/classification.py b/roboflow/models/classification.py index 191e3ece..15c8be94 100644 --- a/roboflow/models/classification.py +++ b/roboflow/models/classification.py @@ -81,7 +81,7 @@ def predict(self, image_path, hosted=False): # type: ignore[override] >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ diff --git a/roboflow/models/inference.py b/roboflow/models/inference.py index ca3f4503..0a0fffc7 100644 --- a/roboflow/models/inference.py +++ b/roboflow/models/inference.py @@ -124,7 +124,7 @@ def predict(self, image_path, prediction_type=None, **kwargs): >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ @@ -170,7 +170,7 @@ def predict_video( >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> job_id,signed_url,signed_url_expires = model.predict_video("video.mp4" ,fps=5, inference_type="object-detection") @@ -307,7 +307,7 @@ def poll_for_video_results(self, job_id: Optional[str] = None) -> dict: >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("video.mp4") @@ -355,7 +355,7 @@ def poll_until_video_results(self, job_id) -> dict: >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("video.mp4") diff --git a/roboflow/models/instance_segmentation.py b/roboflow/models/instance_segmentation.py index b26c1f36..a04ccc8e 100644 --- a/roboflow/models/instance_segmentation.py +++ b/roboflow/models/instance_segmentation.py @@ -53,7 +53,7 @@ def predict(self, image_path, confidence=40): # type: ignore[override] >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ # noqa: E501 diff --git a/roboflow/models/keypoint_detection.py b/roboflow/models/keypoint_detection.py index f97dc4e5..c3b7321e 100644 --- a/roboflow/models/keypoint_detection.py +++ b/roboflow/models/keypoint_detection.py @@ -8,7 +8,7 @@ import requests from PIL import Image -from roboflow.config import CLASSIFICATION_MODEL +from roboflow.config import KEYPOINT_DETECTION_MODEL from roboflow.models.inference import InferenceModel from roboflow.util.image_utils import check_image_url from roboflow.util.prediction import PredictionGroup @@ -52,6 +52,7 @@ def __init__( self.name = name self.confidence = confidence self.version = version + self.colors = {} self.base_url = "https://serverless.roboflow.com/" if self.name is not None and version is not None: @@ -79,7 +80,7 @@ def predict(self, image_path, hosted=False, confidence=None): # type: ignore[ov >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ @@ -119,7 +120,7 @@ def predict(self, image_path, hosted=False, confidence=None): # type: ignore[ov resp.json(), image_dims=img_dims, image_path=image_path, - prediction_type=CLASSIFICATION_MODEL, + prediction_type=KEYPOINT_DETECTION_MODEL, colors=self.colors, ) diff --git a/roboflow/models/object_detection.py b/roboflow/models/object_detection.py index 559b2f2d..5793ec86 100644 --- a/roboflow/models/object_detection.py +++ b/roboflow/models/object_detection.py @@ -152,7 +152,7 @@ def predict( # type: ignore[override] >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ diff --git a/roboflow/models/semantic_segmentation.py b/roboflow/models/semantic_segmentation.py index c15b0c74..5dfd5659 100644 --- a/roboflow/models/semantic_segmentation.py +++ b/roboflow/models/semantic_segmentation.py @@ -36,7 +36,7 @@ def predict(self, image_path: str, confidence: int = 50): # type: ignore[overri >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("YOUR_IMAGE.jpg") """ # noqa: E501 // docs diff --git a/roboflow/models/video.py b/roboflow/models/video.py index 401a2aab..e1cae97b 100644 --- a/roboflow/models/video.py +++ b/roboflow/models/video.py @@ -90,7 +90,7 @@ def predict( # type: ignore[override] >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("video.mp4", fps=5, inference_type="object-detection") """ # noqa: E501 // docs @@ -164,7 +164,7 @@ def poll_for_results(self, job_id: Optional[str] = None) -> dict: >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("video.mp4") @@ -216,7 +216,7 @@ def poll_until_results(self, job_id) -> dict: >>> project = rf.workspace().project("PROJECT_ID") - >>> model = project.version("1").model + >>> model = project.version("1").models()[0] >>> prediction = model.predict("video.mp4") diff --git a/roboflow/util/prediction.py b/roboflow/util/prediction.py index d4740e58..77d4cd73 100644 --- a/roboflow/util/prediction.py +++ b/roboflow/util/prediction.py @@ -10,6 +10,7 @@ from roboflow.config import ( CLASSIFICATION_MODEL, INSTANCE_SEGMENTATION_MODEL, + KEYPOINT_DETECTION_MODEL, OBJECT_DETECTION_MODEL, PREDICTION_OBJECT, SEMANTIC_SEGMENTATION_MODEL, @@ -57,7 +58,7 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60, colors=Non prediction = prediction or {} stroke_color = "r" - if prediction["prediction_type"] == OBJECT_DETECTION_MODEL: + if prediction["prediction_type"] in (OBJECT_DETECTION_MODEL, KEYPOINT_DETECTION_MODEL): if prediction["class"] in colors.keys(): stroke_color = colors[prediction["class"]] @@ -158,7 +159,7 @@ def save(self, output_path="predictions.jpg", stroke=2, transparency=60): image = self.__load_image() stroke_color = (255, 0, 0) - if self["prediction_type"] == OBJECT_DETECTION_MODEL: + if self["prediction_type"] in (OBJECT_DETECTION_MODEL, KEYPOINT_DETECTION_MODEL): # Get different dimensions/coordinates x = self["x"] y = self["y"] @@ -346,7 +347,7 @@ def save(self, output_path="predictions.jpg", stroke=2): # Iterate through predictions and add prediction to image for prediction in self.predictions: # Check what type of prediction it is - if self.base_prediction_type == OBJECT_DETECTION_MODEL: + if self.base_prediction_type in (OBJECT_DETECTION_MODEL, KEYPOINT_DETECTION_MODEL): # Get different dimensions/coordinates x = prediction["x"] y = prediction["y"] @@ -509,7 +510,7 @@ def create_prediction_group(json_response, image_path, prediction_type, image_di colors = {} if colors is None else colors prediction_list = [] - if prediction_type in [OBJECT_DETECTION_MODEL, INSTANCE_SEGMENTATION_MODEL]: + if prediction_type in [OBJECT_DETECTION_MODEL, INSTANCE_SEGMENTATION_MODEL, KEYPOINT_DETECTION_MODEL]: for prediction in json_response["predictions"]: prediction = Prediction( prediction, diff --git a/tests/models/test_keypoint_detection.py b/tests/models/test_keypoint_detection.py index e7c8accd..f49be85f 100644 --- a/tests/models/test_keypoint_detection.py +++ b/tests/models/test_keypoint_detection.py @@ -6,6 +6,7 @@ import responses from dotenv import load_dotenv +from roboflow.config import KEYPOINT_DETECTION_MODEL from roboflow.models.keypoint_detection import KeypointDetectionModel from roboflow.util.prediction import PredictionGroup @@ -46,7 +47,9 @@ def test_predict_local_image(self): result = instance.predict("tests/images/MM2A_46_R_T.png") self.assertIsInstance(result, PredictionGroup) - self.assertEqual(len(result.predictions), 1) + self.assertEqual(len(result.predictions), len(MOCK_RESPONSE["predictions"])) + self.assertEqual(result.predictions[0]["prediction_type"], KEYPOINT_DETECTION_MODEL) + self.assertIn("keypoints", result.predictions[0].json()) @responses.activate def test_predict_with_confidence(self): diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 00000000..89420c1d --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,121 @@ +import unittest +from unittest.mock import patch + +from roboflow.config import ( + CLASSIFICATION_MODEL, + INSTANCE_SEGMENTATION_MODEL, + KEYPOINT_DETECTION_MODEL, + OBJECT_DETECTION_MODEL, + SEMANTIC_SEGMENTATION_MODEL, +) +from roboflow.core.training import TrainedModel, Training +from roboflow.models.classification import ClassificationModel +from roboflow.models.instance_segmentation import InstanceSegmentationModel +from roboflow.models.keypoint_detection import KeypointDetectionModel +from roboflow.models.object_detection import ObjectDetectionModel +from roboflow.models.semantic_segmentation import SemanticSegmentationModel + + +class TestTrainedModelPredict(unittest.TestCase): + def test_predict_routes_through_shared_inference_model_with_task_prediction_type(self): + cases = [ + ("yolov11", OBJECT_DETECTION_MODEL, "https://serverless.roboflow.com/ws/model-slug"), + ("yolov11-cls", CLASSIFICATION_MODEL, "https://serverless.roboflow.com/ws/model-slug"), + ("yolov11-seg", INSTANCE_SEGMENTATION_MODEL, "https://serverless.roboflow.com/ws/model-slug"), + ("yolov11-pose", KEYPOINT_DETECTION_MODEL, "https://serverless.roboflow.com/ws/model-slug"), + ("yolo26-sem", SEMANTIC_SEGMENTATION_MODEL, "https://segment.roboflow.com/ws/model-slug"), + ] + + for model_type, prediction_type, api_url in cases: + with self.subTest(model_type=model_type): + model = TrainedModel("key", "ws", "proj", "ws/model-slug", model_type=model_type) + with patch( + "roboflow.core.training.InferenceModel.predict", + autospec=True, + return_value="ok", + ) as predict: + result = model.predict("image.jpg", confidence=17, overlap=9, format="json") + + inference_model = predict.call_args.args[0] + self.assertEqual(result, "ok") + self.assertEqual(inference_model.api_url, api_url) + self.assertEqual(predict.call_args.kwargs["prediction_type"], prediction_type) + self.assertEqual(predict.call_args.kwargs["confidence"], 17) + self.assertEqual(predict.call_args.kwargs["overlap"], 9) + self.assertEqual(predict.call_args.kwargs["format"], "json") + + +class TestTrainedModelVideo(unittest.TestCase): + def test_predict_video_routes_through_task_appropriate_legacy_model(self): + cases = [ + ("yolov11", ObjectDetectionModel), + ("yolov11-cls", ClassificationModel), + ("yolov11-seg", InstanceSegmentationModel), + ("yolov11-pose", KeypointDetectionModel), + ("yolo26-sem", SemanticSegmentationModel), + ] + + for model_type, legacy_class in cases: + with self.subTest(model_type=model_type): + model = TrainedModel("key", "ws", "proj", "ws/model-slug", model_type=model_type) + with patch.object( + legacy_class, + "predict_video", + autospec=True, + return_value=("job-1", "signed-url", None), + ) as predict_video: + result = model.predict_video("video.mp4", fps=9) + + legacy_model = predict_video.call_args.args[0] + self.assertIsInstance(legacy_model, legacy_class) + self.assertEqual(legacy_model.id, "ws/proj/model-slug") + self.assertEqual(result, ("job-1", "signed-url", None)) + self.assertEqual(predict_video.call_args.kwargs["fps"], 9) + + def test_poll_reuses_the_predict_video_legacy_model(self): + model = TrainedModel("key", "ws", "proj", "ws/model-slug", model_type="yolov11") + + with ( + patch.object(ObjectDetectionModel, "predict_video", autospec=True, return_value=("job-1", "url", None)), + patch.object( + ObjectDetectionModel, "poll_until_video_results", autospec=True, return_value={"frames": []} + ) as poll, + ): + model.predict_video("video.mp4") + result = model.poll_until_video_results("job-1") + + self.assertEqual(result, {"frames": []}) + self.assertIs(poll.call_args.args[0], model._video_model()) + + +class TestTrainingModels(unittest.TestCase): + def test_models_are_cached_until_refresh(self): + training = Training("key", "ws", "proj", "1", {"trainingId": "training-1"}) + bundle = { + "status": "finished", + "modelType": "yolov11-cls", + "modelGroup": "group-1", + "modelIds": ["ws/model-slug"], + "models": [{"modelId": "ws/model-slug"}], + } + + with patch("roboflow.core.training.rfapi.get_training", return_value=bundle) as get_training: + first = training.models + second = training.models + training.refresh() + third = training.models + + self.assertIs(first, second) + self.assertEqual(first[0].model_id, "ws/model-slug") + self.assertEqual(first[0].model_type, "yolov11-cls") + self.assertEqual(third[0].model_id, "ws/model-slug") + self.assertEqual(third[0].model_type, "yolov11-cls") + self.assertEqual(training.status, "finished") + self.assertEqual(training.model_type, "yolov11-cls") + self.assertEqual(training.model_group, "group-1") + self.assertEqual(training.model_ids, ["ws/model-slug"]) + self.assertEqual(get_training.call_count, 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_version.py b/tests/test_version.py index 64b8874e..fefc9597 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,7 +1,9 @@ import os import unittest -from unittest.mock import patch +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +import requests import responses from roboflow.adapters import rfapi @@ -13,6 +15,7 @@ TYPE_SEMANTIC_SEGMENTATION, ) from roboflow.core.version import Version, unwrap_version_id +from roboflow.models.object_detection import ObjectDetectionModel from tests.helpers import get_version @@ -266,3 +269,74 @@ def test_detection_project_rejects_sem_model(self): def test_classification_project_rejects_detection(self): with self.assertRaises(ValueError): self._version(TYPE_CLASSICATION)._validate_against_project_type("yolov11") + + +class TestConstructionDoesNotProbeNetwork(unittest.TestCase): + @patch("roboflow.adapters.rfapi.get_version", side_effect=AssertionError("get_version should not be called")) + def test_construction_makes_no_request_when_payload_has_no_model(self, _mock_get_version: MagicMock): + version = get_version() + self.assertIsNone(version._model) + + @patch( + "roboflow.adapters.rfapi.get_version", + side_effect=requests.exceptions.ConnectionError("network down"), + ) + def test_construction_survives_request_layer_failure(self, _mock_get_version: MagicMock): + # A transient/mocked request failure must not break basic version retrieval. + version = get_version() + self.assertIsNone(version._model) + + @patch("roboflow.adapters.rfapi.get_version", side_effect=AssertionError("get_version should not be called")) + def test_legacy_model_is_derived_from_payload(self, _mock_get_version: MagicMock): + version = get_version(type=TYPE_OBJECT_DETECTION, model={"id": "test-workspace/test-project/2"}) + self.assertIsInstance(version._model, ObjectDetectionModel) + + +class TestMMPVCompatibility(unittest.TestCase): + @patch("roboflow.adapters.rfapi.get_version", return_value={"version": {}}) + def test_model_property_is_deprecated_and_does_not_enumerate_models(self, _mock_get_version: MagicMock): + version = get_version() + with patch.object(Version, "models", side_effect=AssertionError("models should not be called")): + with self.assertWarns(DeprecationWarning): + self.assertIsNone(version.model) + + @patch("roboflow.adapters.rfapi.get_version", return_value={"version": {}}) + def test_models_returns_union_across_trainings(self, _mock_get_version: MagicMock): + version = get_version() + a, b, c = object(), object(), object() + training_one = SimpleNamespace(models=[a, b]) + training_two = SimpleNamespace(models=[c]) + with patch.object(Version, "trainings", return_value=[training_one, training_two]): + self.assertEqual(version.models(), [a, b, c]) + + @patch.object(Version, "_Version__wait_if_generating") + @patch("roboflow.adapters.rfapi.create_training_v2") + @patch("roboflow.adapters.rfapi.get_version", return_value={"version": {}}) + def test_create_training_returns_v2_training( + self, + _mock_get_version: MagicMock, + mock_create_training: MagicMock, + _mock_wait_if_generating: MagicMock, + ): + mock_create_training.return_value = { + "trainingId": "training-1", + "status": "running", + "modelType": "yolov11", + } + version = get_version(version_number="4") + + training = version.create_training(speed="fast", model_type=None, checkpoint="ckpt", epochs=10) + + mock_create_training.assert_called_once_with( + api_key="test-api-key", + workspace_url="test-workspace", + project_url="test-project", + version="4", + speed="fast", + checkpoint="ckpt", + model_type=None, + epochs=10, + ) + self.assertEqual(training.training_id, "training-1") + self.assertEqual(training.status, "running") + self.assertEqual(training.model_type, "yolov11")