Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions inference/core/registries/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
VersionID,
)
from inference.core.env import (
ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
CACHE_METADATA_LOCK_TIMEOUT,
LAMBDA,
MODELS_CACHE_AUTH_CACHE_MAX_SIZE,
Expand Down Expand Up @@ -47,6 +48,12 @@
from inference.core.utils.file_system import dump_json, read_json
from inference.core.utils.roboflow import get_model_id_chunks
from inference.models.aliases import resolve_roboflow_model_alias
from inference_models.models.auto_loaders.core import parse_model_config
from inference_models.models.auto_loaders.entities import MODEL_CONFIG_FILE_NAME

# Fallback model_type for local `inference_models` packages that do not declare
# `model_architecture` in model_config.json. `_get_local_model_type` pairs this
# value with the task_type read from that config when no architecture is declared.
LOCAL_INFERENCE_MODELS_MODEL_TYPE = "inference-models-local"

GENERIC_MODELS = {
"clip": ("embed", "clip"),
Expand Down Expand Up @@ -132,6 +139,8 @@ def _check_if_api_key_has_access_to_model(
service_secret: Optional[str] = None,
) -> bool:
model_id = resolve_roboflow_model_alias(model_id=model_id)
if _get_local_model_type(model_id=model_id) is not None:
return True
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
use_legacy_core_model_auth = (
endpoint_type == ModelEndpointType.CORE_MODEL and dataset_id == "yolo_world"
Expand Down Expand Up @@ -165,6 +174,31 @@ def _check_if_api_key_has_access_to_model(
return True


def _get_local_model_type(model_id: str) -> Optional[Tuple[TaskType, ModelType]]:
    """Return model metadata read from a local `inference_models` package directory.

    A local package is only recognised when all of the following hold:
    `USE_INFERENCE_MODELS` and `ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES`
    are enabled, `model_id` is a string naming an existing directory, that directory
    contains a `MODEL_CONFIG_FILE_NAME` file, and the parsed config declares a
    task type.

    Args:
        model_id: Model identifier; may be a filesystem path to a package directory.

    Returns:
        `(task_type, model_type)` for a local package, where `model_type` is the
        declared `model_architecture`, or `LOCAL_INFERENCE_MODELS_MODEL_TYPE` when
        the config declares none. `None` when `model_id` is not a local package or
        local loading is disabled, in which case the regular Roboflow model id
        resolution applies.

    Note:
        Errors raised by `parse_model_config` for a config file that exists but
        cannot be parsed are not caught here and propagate to the caller.
    """
    local_loading_enabled = (
        USE_INFERENCE_MODELS and ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES
    )
    if not local_loading_enabled:
        return None
    if not isinstance(model_id, str) or not os.path.isdir(model_id):
        return None

    config_path = os.path.join(model_id, MODEL_CONFIG_FILE_NAME)
    # A directory whose name merely coincides with a model id (no config file
    # inside) is not a local package: fall back to the regular resolution rather
    # than handing a non-existent path to the config parser.
    if not os.path.isfile(config_path):
        return None

    model_config = parse_model_config(config_path=config_path)
    if model_config.task_type is None:
        return None
    return (
        model_config.task_type,
        model_config.model_architecture or LOCAL_INFERENCE_MODELS_MODEL_TYPE,
    )


def get_model_type(
model_id: ModelID,
api_key: Optional[str] = None,
Expand All @@ -188,6 +222,9 @@ def get_model_type(
"""

model_id = resolve_roboflow_model_alias(model_id=model_id)
local_model_type = _get_local_model_type(model_id=model_id)
if local_model_type is not None:
return local_model_type
dataset_id, version_id = get_model_id_chunks(model_id=model_id)
# first check if the model id as a whole is in the GENERIC_MODELS dictionary
if model_id in GENERIC_MODELS:
Expand Down
19 changes: 18 additions & 1 deletion inference/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
KeypointsDetectionModelStub,
ObjectDetectionModelStub,
)
from inference.core.registries.roboflow import get_model_type
from inference.core.registries.roboflow import (
LOCAL_INFERENCE_MODELS_MODEL_TYPE,
get_model_type,
)
from inference.core.warnings import InferenceModelsStackMissing, ModelDependencyMissing
from inference.models import (
YOLACT,
Expand Down Expand Up @@ -1144,3 +1147,17 @@ def get_roboflow_model(*args, **kwargs):
)

ROBOFLOW_MODEL_TYPES[("vlm", "glm-ocr")] = InferenceModelsGLMOCRAdapter

# Adapters for models loaded straight from a local directory
# (ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES). The task type comes from
# the local model_config.json; the adapter forwards the path to
# AutoModel.from_pretrained with allow_direct_local_storage_loading=True.
# Every entry is keyed by LOCAL_INFERENCE_MODELS_MODEL_TYPE, the fallback model type
# used when the package does not declare its own architecture.
for local_task, local_adapter in {
    "object-detection": InferenceModelsObjectDetectionAdapter,
    "instance-segmentation": InferenceModelsInstanceSegmentationAdapter,
    "keypoint-detection": InferenceModelsKeyPointsDetectionAdapter,
    "classification": InferenceModelsClassificationAdapter,
    "semantic-segmentation": InferenceModelsSemanticSegmentationAdapter,
}.items():
    ROBOFLOW_MODEL_TYPES[(local_task, LOCAL_INFERENCE_MODELS_MODEL_TYPE)] = (
        local_adapter
    )
23 changes: 23 additions & 0 deletions tests/inference/unit_tests/core/registries/test_roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,29 @@ def test_get_model_type_when_roboflow_api_is_called_for_model_from_new_model_reg
)


@mock.patch.object(
    roboflow, "ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES", True
)
@mock.patch.object(roboflow, "USE_INFERENCE_MODELS", True)
def test_get_model_type_for_local_inference_models_package_uses_declared_architecture(
    empty_local_dir: str,
) -> None:
    """The architecture declared in a local package's config is the model type."""
    # given
    declared_config = {
        "model_architecture": "depth-anything-v2",
        "task_type": "depth-estimation",
        "backend_type": "torch",
    }
    config_path = os.path.join(empty_local_dir, "model_config.json")
    with open(config_path, "w") as config_file:
        json.dump(declared_config, config_file)

    # when
    result = get_model_type(model_id=empty_local_dir, api_key="my_api_key")

    # then
    assert result == ("depth-estimation", "depth-anything-v2")


@mock.patch.object(roboflow, "get_model_metadata_from_inference_models_registry")
@mock.patch.object(roboflow, "construct_model_type_cache_path")
@mock.patch.object(roboflow, "USE_INFERENCE_MODELS", True)
Expand Down