diff --git a/src/fastapi_cli/cli.py b/src/fastapi_cli/cli.py index 4348e599..11d7d124 100644 --- a/src/fastapi_cli/cli.py +++ b/src/fastapi_cli/cli.py @@ -1,4 +1,5 @@ import logging +from importlib.metadata import entry_points as _entry_points from pathlib import Path from typing import Annotated, Any @@ -15,9 +16,7 @@ from .logging import setup_logging from .utils.cli import get_rich_toolkit, get_uvicorn_log_config -app = typer.Typer( - rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]} -) +app = typer.Typer(rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]}) logger = logging.getLogger(__name__) @@ -48,6 +47,39 @@ pass +def _cmd_name(cmd_info: Any) -> str | None: + """Return the effective CLI name for a registered Typer command.""" + if cmd_info.name is not None: + return cmd_info.name + if cmd_info.callback is not None: + return cmd_info.callback.__name__.lower().replace("_", "-") + return None + + +def _load_plugins(typer_app: typer.Typer) -> None: + """Load commands registered via the 'fastapi_cli.plugins' entry point group.""" + known: set[str] = {n for ci in typer_app.registered_commands if (n := _cmd_name(ci))} + for ep in _entry_points(group="fastapi_cli.plugins"): + cursor = len(typer_app.registered_commands) + try: + ep.load()(typer_app) + except Exception as e: + logger.warning("Plugin '%s' failed to load: %s", ep.name, e) + continue + collisions = { + n + for ci in typer_app.registered_commands[cursor:] + if (n := _cmd_name(ci)) and n in known + } + if collisions: + logger.warning( + "Plugin '%s' overrides existing command(s): %s", + ep.name, + ", ".join(sorted(collisions)), + ) + known.update(n for ci in typer_app.registered_commands[cursor:] if (n := _cmd_name(ci))) + + def version_callback(value: bool) -> None: if value: print(f"FastAPI CLI version: [green]{__version__}[/green]") @@ -58,9 +90,7 @@ def version_callback(value: bool) -> None: def callback( version: Annotated[ bool | None, - typer.Option( - "--version", help="Show the version and exit.", callback=version_callback - ), + typer.Option("--version", help="Show the version and exit.", callback=version_callback), ] = None, verbose: bool = typer.Option(False, help="Enable verbose output"), ) -> None: @@ -88,9 +118,7 @@ def _get_module_tree(module_paths: list[Path]) -> Tree: tree = root_tree for sub_path in module_paths[1:]: - sub_name = ( - f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}" - ) + sub_name = f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}" tree = tree.add(sub_name) if sub_path.is_dir(): tree.add("[dim]🐍 __init__.py[/dim]") @@ -125,9 +153,7 @@ def _run( if entrypoint and (path or app): toolkit.print_line() - toolkit.print( - "[error]Cannot use --entrypoint together with path or --app arguments" - ) + toolkit.print("[error]Cannot use --entrypoint together with path or --app arguments") toolkit.print_line() raise typer.Exit(code=1) @@ -221,9 +247,7 @@ def _run( port=port, reload=reload, reload_dirs=( - [str(directory.resolve()) for directory in reload_dirs] - if reload_dirs - else None + [str(directory.resolve()) for directory in reload_dirs] if reload_dirs else None ), workers=workers, root_path=root_path, diff --git a/tests/assets/plugins/__init__.py b/tests/assets/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/assets/plugins/broken.py b/tests/assets/plugins/broken.py new file mode 100644 index 00000000..331730e2 --- /dev/null +++ b/tests/assets/plugins/broken.py @@ -0,0 +1,5 @@ +import typer + + +def register(app: typer.Typer) -> None: + raise RuntimeError("intentionally broken plugin") diff --git a/tests/assets/plugins/colliding.py b/tests/assets/plugins/colliding.py new file mode 100644 index 00000000..0c900b27 --- /dev/null +++ b/tests/assets/plugins/colliding.py @@ -0,0 +1,7 @@ +import typer + + +def register(app: typer.Typer) -> None: + @app.command("dev") # collides with built-in dev command + def dev() -> None: + pass # pragma: no cover diff --git a/tests/assets/plugins/sample.py b/tests/assets/plugins/sample.py new file mode 100644 index 00000000..c0ac5b14 --- /dev/null +++ b/tests/assets/plugins/sample.py @@ -0,0 +1,8 @@ +import typer + + +def register(app: typer.Typer) -> None: + @app.command("ping") + def ping() -> None: + """Test command added by plugin.""" + typer.echo("pong") # pragma: no cover diff --git a/tests/test_cli_plugin.py b/tests/test_cli_plugin.py new file mode 100644 index 00000000..1eada9d7 --- /dev/null +++ b/tests/test_cli_plugin.py @@ -0,0 +1,108 @@ +import sys +from collections.abc import Generator +from importlib.metadata import EntryPoint +from pathlib import Path +from unittest.mock import patch + +import pytest +import typer as _typer +from fastapi_cli.cli import _load_plugins + +PLUGINS_ASSET_PATH: Path = Path(__file__).parent / "assets" + + +@pytest.fixture +def test_app() -> _typer.Typer: + return _typer.Typer() + + +@pytest.fixture +def plugins_on_path() -> Generator[None, None, None]: + original_path = sys.path.copy() + sys.path.insert(0, str(PLUGINS_ASSET_PATH)) + try: + yield + finally: + sys.path[:] = original_path + for key in list(sys.modules.keys()): + if key.startswith("plugins."): + del sys.modules[key] + + +def _entry_point(name: str, module_attr: str) -> EntryPoint: + return EntryPoint( + name=name, + value=f"plugins.{module_attr}", + group="fastapi_cli.plugins", + ) + + +def test_load_plugins_happy_path(plugins_on_path: None, test_app: _typer.Typer) -> None: + """Plugin registers its command on the Typer app.""" + + ep = _entry_point("sample", "sample:register") + + with patch("fastapi_cli.cli._entry_points", return_value=[ep]): + _load_plugins(test_app) + + names = {ci.name for ci in test_app.registered_commands} + assert "ping" in names + + +def test_load_plugins_logs_on_failure(plugins_on_path: None, test_app: _typer.Typer) -> None: + """A plugin that raises is skipped and a warning is logged.""" + + ep = _entry_point("broken", "broken:register") + + with ( + patch("fastapi_cli.cli._entry_points", return_value=[ep]), + patch("fastapi_cli.cli.logger") as mock_logger, + ): + _load_plugins(test_app) + + mock_logger.warning.assert_called_once() + _fmt, ep_name, *_ = mock_logger.warning.call_args.args + assert "broken" in ep_name + + +def test_load_plugins_warns_on_collision_with_builtin( + plugins_on_path: None, test_app: _typer.Typer +) -> None: + """Plugin registering a name already on the app triggers a collision warning.""" + + @test_app.command("dev") + def existing() -> None: + pass # pragma: no cover + + ep = _entry_point("colliding", "colliding:register") + + with ( + patch("fastapi_cli.cli._entry_points", return_value=[ep]), + patch("fastapi_cli.cli.logger") as mock_logger, + ): + _load_plugins(test_app) + + mock_logger.warning.assert_called_once() + _fmt, ep_name, collisions = mock_logger.warning.call_args.args + assert "colliding" in ep_name + assert "dev" in collisions + + +def test_load_plugins_warns_on_cross_plugin_collision( + plugins_on_path: None, test_app: _typer.Typer +) -> None: + """Two plugins registering the same name: only the second gets a warning.""" + + ep_a = _entry_point("sample", "sample:register") # registers "ping" + ep_b = _entry_point("colliding2", "sample:register") # also tries to register "ping" + + with ( + patch("fastapi_cli.cli._entry_points", return_value=[ep_a, ep_b]), + patch("fastapi_cli.cli.logger") as mock_logger, + ): + _load_plugins(test_app) + + assert mock_logger.warning.call_count == 1 + _fmt, ep_name, collisions = mock_logger.warning.call_args.args + assert "colliding2" in ep_name + assert "ping" in collisions