Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions src/fastapi_cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from importlib.metadata import entry_points as _entry_points
from pathlib import Path
from typing import Annotated, Any

Expand All @@ -15,9 +16,7 @@
from .logging import setup_logging
from .utils.cli import get_rich_toolkit, get_uvicorn_log_config

app = typer.Typer(
rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]}
)
app = typer.Typer(rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]})

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,6 +47,39 @@
pass


def _cmd_name(cmd_info: Any) -> str | None:
"""Return the effective CLI name for a registered Typer command."""
if cmd_info.name is not None:
return cmd_info.name
if cmd_info.callback is not None:
return cmd_info.callback.__name__.lower().replace("_", "-")
return None


def _load_plugins(typer_app: typer.Typer) -> None:
"""Load commands registered via the 'fastapi_cli.plugins' entry point group."""
known: set[str] = {n for ci in typer_app.registered_commands if (n := _cmd_name(ci))}
for ep in _entry_points(group="fastapi_cli.plugins"):
cursor = len(typer_app.registered_commands)
try:
ep.load()(typer_app)
except Exception as e:
logger.warning("Plugin '%s' failed to load: %s", ep.name, e)
continue
collisions = {
n
for ci in typer_app.registered_commands[cursor:]
if (n := _cmd_name(ci)) and n in known
}
if collisions:
logger.warning(
"Plugin '%s' overrides existing command(s): %s",
ep.name,
", ".join(sorted(collisions)),
)
known.update(n for ci in typer_app.registered_commands[cursor:] if (n := _cmd_name(ci)))


def version_callback(value: bool) -> None:
if value:
print(f"FastAPI CLI version: [green]{__version__}[/green]")
Expand All @@ -58,9 +90,7 @@ def version_callback(value: bool) -> None:
def callback(
version: Annotated[
bool | None,
typer.Option(
"--version", help="Show the version and exit.", callback=version_callback
),
typer.Option("--version", help="Show the version and exit.", callback=version_callback),
] = None,
verbose: bool = typer.Option(False, help="Enable verbose output"),
) -> None:
Expand Down Expand Up @@ -88,9 +118,7 @@ def _get_module_tree(module_paths: list[Path]) -> Tree:

tree = root_tree
for sub_path in module_paths[1:]:
sub_name = (
f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}"
)
sub_name = f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}"
tree = tree.add(sub_name)
if sub_path.is_dir():
tree.add("[dim]🐍 __init__.py[/dim]")
Expand Down Expand Up @@ -125,9 +153,7 @@ def _run(

if entrypoint and (path or app):
toolkit.print_line()
toolkit.print(
"[error]Cannot use --entrypoint together with path or --app arguments"
)
toolkit.print("[error]Cannot use --entrypoint together with path or --app arguments")
toolkit.print_line()
raise typer.Exit(code=1)

Expand Down Expand Up @@ -221,9 +247,7 @@ def _run(
port=port,
reload=reload,
reload_dirs=(
[str(directory.resolve()) for directory in reload_dirs]
if reload_dirs
else None
[str(directory.resolve()) for directory in reload_dirs] if reload_dirs else None
),
workers=workers,
root_path=root_path,
Expand Down
Empty file.
5 changes: 5 additions & 0 deletions tests/assets/plugins/broken.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import typer


def register(app: typer.Typer) -> None:
raise RuntimeError("intentionally broken plugin")
7 changes: 7 additions & 0 deletions tests/assets/plugins/colliding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import typer


def register(app: typer.Typer) -> None:
@app.command("dev") # collides with built-in dev command
def dev() -> None:
pass # pragma: no cover
8 changes: 8 additions & 0 deletions tests/assets/plugins/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typer


def register(app: typer.Typer) -> None:
@app.command("ping")
def ping() -> None:
"""Test command added by plugin."""
typer.echo("pong") # pragma: no cover
108 changes: 108 additions & 0 deletions tests/test_cli_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import sys
from collections.abc import Generator
from importlib.metadata import EntryPoint
from pathlib import Path
from unittest.mock import patch

import pytest
import typer as _typer
from fastapi_cli.cli import _load_plugins

PLUGINS_ASSET_PATH: Path = Path(__file__).parent / "assets"


@pytest.fixture
def test_app() -> _typer.Typer:
return _typer.Typer()


@pytest.fixture
def plugins_on_path() -> Generator[None, None, None]:
original_path = sys.path.copy()
sys.path.insert(0, str(PLUGINS_ASSET_PATH))
try:
yield
finally:
sys.path[:] = original_path
for key in list(sys.modules.keys()):
if key.startswith("plugins."):
del sys.modules[key]


def _entry_point(name: str, module_attr: str) -> EntryPoint:
return EntryPoint(
name=name,
value=f"plugins.{module_attr}",
group="fastapi_cli.plugins",
)


def test_load_plugins_happy_path(plugins_on_path: None, test_app: _typer.Typer) -> None:
"""Plugin registers its command on the Typer app."""

ep = _entry_point("sample", "sample:register")

with patch("fastapi_cli.cli._entry_points", return_value=[ep]):
_load_plugins(test_app)

names = {ci.name for ci in test_app.registered_commands}
assert "ping" in names


def test_load_plugins_logs_on_failure(plugins_on_path: None, test_app: _typer.Typer) -> None:
"""A plugin that raises is skipped and a warning is logged."""

ep = _entry_point("broken", "broken:register")

with (
patch("fastapi_cli.cli._entry_points", return_value=[ep]),
patch("fastapi_cli.cli.logger") as mock_logger,
):
_load_plugins(test_app)

mock_logger.warning.assert_called_once()
_fmt, ep_name, *_ = mock_logger.warning.call_args.args
assert "broken" in ep_name


def test_load_plugins_warns_on_collision_with_builtin(
plugins_on_path: None, test_app: _typer.Typer
) -> None:
"""Plugin registering a name already on the app triggers a collision warning."""

@test_app.command("dev")
def existing() -> None:
pass # pragma: no cover

ep = _entry_point("colliding", "colliding:register")

with (
patch("fastapi_cli.cli._entry_points", return_value=[ep]),
patch("fastapi_cli.cli.logger") as mock_logger,
):
_load_plugins(test_app)

mock_logger.warning.assert_called_once()
_fmt, ep_name, collisions = mock_logger.warning.call_args.args
assert "colliding" in ep_name
assert "dev" in collisions


def test_load_plugins_warns_on_cross_plugin_collision(
plugins_on_path: None, test_app: _typer.Typer
) -> None:
"""Two plugins registering the same name: only the second gets a warning."""

ep_a = _entry_point("sample", "sample:register") # registers "ping"
ep_b = _entry_point("colliding2", "sample:register") # also tries to register "ping"

with (
patch("fastapi_cli.cli._entry_points", return_value=[ep_a, ep_b]),
patch("fastapi_cli.cli.logger") as mock_logger,
):
_load_plugins(test_app)

assert mock_logger.warning.call_count == 1
_fmt, ep_name, collisions = mock_logger.warning.call_args.args
assert "colliding2" in ep_name
assert "ping" in collisions
Loading