Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ idna==3.7
cycler
kiwisolver>=1.3.1
matplotlib
numpy>=1.18.5
# numpy 2.4 ships PEP 695 `type` statements in its stubs, which mypy rejects
# under python_version=3.10 (see [tool.mypy] in pyproject.toml). Cap below 2.4,
# matching rf-detr's typing constraint.
numpy>=1.18.5,<2.4
opencv-python-headless==4.10.0.84
Pillow>=7.1.2
# https://github.com/roboflow/roboflow-python/issues/390
Expand Down
34 changes: 27 additions & 7 deletions roboflow/util/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ def task_of_model_type(model_type: str) -> str:

Non-detect tasks double as the model_type suffix token
(e.g. 'yolov11-seg' -> TASK_SEG). Plain 'yolov11' / 'rfdetr-base' -> TASK_DET.

Keypoint/pose models may spell the token as either 'pose' (Ultralytics) or
'keypoint' (rf-detr, e.g. 'rfdetr-keypoint-preview'); both map to TASK_POSE.
"""
s = model_type.lower()
if "keypoint" in s:
return TASK_POSE
for task in (TASK_SEM, TASK_SEG, TASK_POSE, TASK_CLS, TASK_OBB):
if task in s:
return task
Expand Down Expand Up @@ -317,21 +322,34 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> tuple[str,
def _detect_rfdetr_task(checkpoint) -> Optional[str]:
"""Detect the training task of an rf-detr checkpoint.

rf-detr currently only supports weight upload for detection and instance
segmentation. Modern checkpoints (rf-detr v1.7+) store the Python class
name at `checkpoint["model_name"]` (e.g. 'RFDETRNano' vs 'RFDETRSegNano');
older checkpoints — including those downloaded from Roboflow — lack that
field but always carry `args.segmentation_head: bool`.
rf-detr supports weight upload for detection, instance segmentation, and
keypoint detection. Modern checkpoints (rf-detr v1.7+) store the Python
class name at `checkpoint["model_name"]` (e.g. 'RFDETRNano' vs
'RFDETRSegNano' vs 'RFDETRKeypointPreview').

The deploy bundle written by rf-detr's `export_for_roboflow` only serialises
`{"model", "args"}` — it drops `model_name` — so detection must also work
from `args`: keypoint checkpoints carry a non-empty `args.num_keypoints_per_class`,
and detection/segmentation checkpoints carry `args.segmentation_head: bool`.
"""
if not isinstance(checkpoint, dict):
return None
model_name = checkpoint.get("model_name")
if isinstance(model_name, str):
return TASK_SEG if TASK_SEG in model_name.lower() else TASK_DET
name = model_name.lower()
if "keypoint" in name:
return TASK_POSE
return TASK_SEG if TASK_SEG in name else TASK_DET
args = checkpoint.get("args")
if args is None:
return None
seg_head = args.get("segmentation_head") if isinstance(args, dict) else getattr(args, "segmentation_head", None)

def _arg(key):
return args.get(key) if isinstance(args, dict) else getattr(args, key, None)

if _arg("num_keypoints_per_class"):
return TASK_POSE
seg_head = _arg("segmentation_head")
if seg_head is True:
return TASK_SEG
if seg_head is False:
Expand All @@ -356,6 +374,8 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[st
"rfdetr-seg-large",
"rfdetr-seg-xlarge",
"rfdetr-seg-2xlarge",
# Keypoint detection models
"rfdetr-keypoint-preview",
]
if model_type not in _supported_types:
raise ValueError(f"Model type {model_type} not supported. Supported types are {_supported_types}")
Expand Down
14 changes: 14 additions & 0 deletions tests/util/test_model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_segment(self):

def test_pose(self):
self.assertEqual(task_of_model_type("yolov11-pose"), TASK_POSE)
self.assertEqual(task_of_model_type("rfdetr-keypoint-preview"), TASK_POSE)

def test_classify(self):
self.assertEqual(task_of_model_type("yolov11-cls"), TASK_CLS)
Expand Down Expand Up @@ -74,6 +75,19 @@ def test_detection_model_names(self):
for name in ("RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge", "RFDETRXLarge"):
self.assertEqual(_detect_rfdetr_task({"model_name": name}), TASK_DET, name)

def test_keypoint_model_names(self):
self.assertEqual(_detect_rfdetr_task({"model_name": "RFDETRKeypointPreview"}), TASK_POSE)

def test_keypoint_args_fallback(self):
# The deploy bundle from export_for_roboflow carries `args` but not
# `model_name`; a non-empty `num_keypoints_per_class` marks a keypoint model.
self.assertEqual(_detect_rfdetr_task({"args": SimpleNamespace(num_keypoints_per_class=[0, 17])}), TASK_POSE)
self.assertEqual(_detect_rfdetr_task({"args": {"num_keypoints_per_class": [0, 17]}}), TASK_POSE)
# Empty / absent keypoint schema must NOT be treated as a keypoint model.
self.assertEqual(
_detect_rfdetr_task({"args": {"num_keypoints_per_class": [], "segmentation_head": False}}), TASK_DET
)

def test_segmentation_head_fallback(self):
# Roboflow-hosted rf-detr .pt downloads lack `model_name` but always carry
# `args.segmentation_head`. Cover both namespace and dict shapes.
Expand Down
Loading