Python Build a CLI Development Framework: Practice Problems & Exercises
Practice: Build a CLI Development Framework
← Back to lesson
Easy
Build an argument parser for a greeting CLI tool with a positional name, optional --count integer, and --verbose flag.
import argparse
def build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the ``greet`` command.

    Accepts a positional ``name``, an integer ``--count``/``-c`` (default 1)
    and a boolean ``--verbose``/``-v`` flag.
    """
    parser = argparse.ArgumentParser(prog="greet", description="Greet someone from the command line")
    # (flags, keyword options) for each argument, registered in order.
    argument_specs = (
        (("name",), {"type": str, "help": "Name of the person to greet"}),
        (("--count", "-c"), {"type": int, "default": 1, "help": "Number of times to greet (default: 1)"}),
        (("--verbose", "-v"), {"action": "store_true", "help": "Enable verbose output"}),
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    return parser
def run_greet(args):
    """Print ``Hello, <name>!`` ``args.count`` times; the name is upper-cased in verbose mode."""
    if args.verbose:
        shown_name = args.name.upper()
    else:
        shown_name = args.name
    line = "Hello, {}!".format(shown_name)
    for _repeat in range(args.count):
        print(line)
# Simulate: greet Alice --count 3 --verbose
parser = build_parser()
args = parser.parse_args("Alice --count 3 --verbose".split())
print(args)
display_name = args.name.upper() if args.verbose else args.name
print(f"Greeting: Hello, {display_name}! (repeated {args.count} times)")
Solution
import argparse
def build_parser() -> argparse.ArgumentParser:
    """Create the ``greet`` parser: positional name, ``-c/--count`` int, ``-v/--verbose`` flag."""
    parser = argparse.ArgumentParser(
        prog="greet",
        description="Greet someone from the command line",
    )
    for flags, options in [
        (["name"], {"type": str, "help": "Name of the person to greet"}),
        (["--count", "-c"], {"type": int, "default": 1, "help": "Number of times to greet"}),
        (["--verbose", "-v"], {"action": "store_true", "help": "Enable verbose output"}),
    ]:
        parser.add_argument(*flags, **options)
    return parser
demo_argv = ["Alice", "--count", "3", "--verbose"]
parser = build_parser()
args = parser.parse_args(demo_argv)
print(args)
if args.verbose:
    name = args.name.upper()
else:
    name = args.name
print(f"Greeting: Hello, {name}! (repeated {args.count} times)")
argparse internals: add_argument registers an Action object in the parser's action list. When parse_args() runs, it walks the argument strings (sys.argv[1:] by default, or the list you pass in), matches each token to a registered action, and populates a Namespace object. Understanding this lets you write custom Action subclasses for behaviors that the built-in actions (store_true, store_const, append, count, ...) don't already cover.
Expected Output
Namespace(name='Alice', count=3, verbose=True)
Greeting: Hello, ALICE! (repeated 3 times)
Hints
Hint 1: Use `argparse.ArgumentParser()`. Add positional args with `add_argument("name")` and optional args with `add_argument("--count", type=int, default=1)`.
Hint 2: Boolean flags use `action="store_true"` — no value needed on the command line, just the flag presence sets it to True.
Build a colored terminal output formatter that uses ANSI escape codes for INFO, WARNING, ERROR, and SUCCESS messages.
import sys
class TerminalFormatter:
    """Format log-style messages with a level tag, ANSI-colored when enabled.

    ``use_color=None`` auto-detects: color is used only when stdout is a TTY.
    """

    COLORS = {
        "reset": "\033[0m",
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "cyan": "\033[36m",
        "bold": "\033[1m",
    }

    def __init__(self, use_color: bool = None):
        self.use_color = sys.stdout.isatty() if use_color is None else use_color

    def _color(self, text: str, color: str) -> str:
        """Wrap *text* in the escape code for *color* (unknown names get no code) plus a reset."""
        if self.use_color:
            return self.COLORS.get(color, "") + text + self.COLORS["reset"]
        return text

    def _tagged(self, tag: str, color: str, msg: str) -> str:
        """Return ``<colored tag> <msg>``."""
        return " ".join((self._color(tag, color), msg))

    def info(self, msg: str) -> str:
        return self._tagged("[INFO]", "blue", msg)

    def warning(self, msg: str) -> str:
        return self._tagged("[WARNING]", "yellow", msg)

    def error(self, msg: str) -> str:
        return self._tagged("[ERROR]", "red", msg)

    def success(self, msg: str) -> str:
        return self._tagged("[SUCCESS]", "green", msg)
# Test (plain output for testability)
fmt = TerminalFormatter(use_color=False)
for emit, text in (
    (fmt.info, "Server started"),
    (fmt.warning, "Low memory"),
    (fmt.error, "Connection failed"),
    (fmt.success, "Deployment complete"),
):
    print(emit(text))
Solution
import sys
class TerminalFormatter:
    """Prefix messages with INFO/WARNING/ERROR/SUCCESS tags, colored via ANSI codes when enabled."""

    COLORS = {
        "reset": "\033[0m",
        "red": "\033[31m", "green": "\033[32m",
        "yellow": "\033[33m", "blue": "\033[34m",
        "cyan": "\033[36m", "bold": "\033[1m",
    }

    def __init__(self, use_color: bool = None):
        # None means "decide from the terminal": color only for an interactive stdout.
        if use_color is None:
            use_color = sys.stdout.isatty()
        self.use_color = use_color

    def _color(self, text: str, color: str) -> str:
        if not self.use_color:
            return text
        start = self.COLORS.get(color, "")
        return "".join((start, text, self.COLORS["reset"]))

    def info(self, msg: str) -> str:
        return self._color("[INFO]", "blue") + " " + msg

    def warning(self, msg: str) -> str:
        return self._color("[WARNING]", "yellow") + " " + msg

    def error(self, msg: str) -> str:
        return self._color("[ERROR]", "red") + " " + msg

    def success(self, msg: str) -> str:
        return self._color("[SUCCESS]", "green") + " " + msg
fmt = TerminalFormatter(use_color=False)
demo_lines = [
    fmt.info("Server started"),
    fmt.warning("Low memory"),
    fmt.error("Connection failed"),
    fmt.success("Deployment complete"),
]
for rendered in demo_lines:
    print(rendered)
isatty() check: Always guard ANSI codes with sys.stdout.isatty(). When output is piped (myapp | grep error), isatty() returns False and you should emit plain text — escape codes in piped output pollute log files and break parsing. Click does this for you: click.echo() strips the escape codes added by click.style() when the output is not a terminal.
Expected Output
[INFO] Server started
[WARNING] Low memory
[ERROR] Connection failed
[SUCCESS] Deployment complete
Hints
Hint 1: ANSI escape codes: \033[32m = green, \033[33m = yellow, \033[31m = red, \033[0m = reset. Always reset after colored text.
Hint 2: Check if the terminal supports colors with `sys.stdout.isatty()` and fall back to plain text if not (e.g., when output is piped to a file).
Medium
Build a CLI class with a @cli.command() decorator that registers subcommands and dispatches to them based on the first argument.
import argparse
from typing import Dict, Callable, List, Optional
import sys
class CLI:
    """Subcommand dispatcher built on argparse subparsers.

    ``command()`` registers a handler per subcommand; ``run()`` parses argv,
    calls the matching handler with the parsed Namespace and returns an
    integer exit code instead of terminating the process.
    """

    def __init__(self, prog: str, description: str = ""):
        self._prog = prog
        self._parser = argparse.ArgumentParser(prog=prog, description=description)
        # dest="command" stores the chosen subcommand name on the Namespace.
        self._subparsers = self._parser.add_subparsers(dest="command")
        self._commands: Dict[str, Callable] = {}

    def command(self, name: str, help: str = "", setup: Callable = None):
        """Decorator factory to register a subcommand.

        The subparser is created immediately; *setup*, if given, is called with
        it so the caller can add the subcommand's arguments. The decorated
        function becomes the handler, gets a ``parser`` attribute pointing at
        its subparser, and is returned otherwise unchanged.
        """
        sub = self._subparsers.add_parser(name, help=help)
        if setup:
            setup(sub)

        def decorator(fn: Callable):
            sub.set_defaults(handler=fn)
            self._commands[name] = fn
            fn.parser = sub
            return fn

        return decorator

    def run(self, argv: Optional[List[str]] = None) -> int:
        """Parse *argv* (``sys.argv[1:]`` when None) and dispatch to its handler.

        Returns:
            The handler's return value, or 0 if it returned something falsy;
            1 when no subcommand was given (help is printed first); 2 when the
            subcommand has no handler; otherwise argparse's own exit status
            when parsing stops early (2 for usage errors such as an unknown
            subcommand, 0 for ``--help``).
        """
        try:
            args = self._parser.parse_args(argv)
        except SystemExit as exc:
            # argparse reports usage errors (e.g. an unknown subcommand) and
            # --help by calling sys.exit(); convert that into a return value so
            # run() keeps its "returns an exit code" contract.
            return self._exit_status(exc)
        if not args.command:
            self._parser.print_help()
            return 1
        handler = getattr(args, "handler", None)
        if handler is None:
            print(f"Unknown command: {args.command}", file=sys.stderr)
            return 2
        return handler(args) or 0

    @staticmethod
    def _exit_status(exc: SystemExit) -> int:
        """Map a SystemExit payload to an int status the way ``sys.exit()`` does."""
        if exc.code is None:
            return 0
        return exc.code if isinstance(exc.code, int) else 1
# Build a CLI tool
cli = CLI("mytool", "My deployment tool")


def setup_deploy(parser):
    """Add the options of the ``deploy`` subcommand."""
    parser.add_argument("--env", default="staging", help="Target environment")


@cli.command("deploy", help="Deploy the application", setup=setup_deploy)
def deploy(args):
    print(f"Deploying to {args.env}")
    return 0


def setup_test(parser):
    """Add the options of the ``test`` subcommand."""
    parser.add_argument("--suite", default="all", help="Test suite to run")


@cli.command("test", help="Run tests", setup=setup_test)
def run_tests(args):
    print(f"Running tests: {args.suite}")
    return 0


for label, demo_argv in (
    ("Deploy complete", ["deploy", "--env", "prod"]),
    ("Test exit code", ["test", "--suite", "unit"]),
    ("Unknown command", ["badcmd"]),
):
    print(f"{label}: {cli.run(demo_argv)}")
Solution
import argparse
from typing import Dict, Callable, List, Optional
import sys
class CLI:
    """Minimal argparse-backed subcommand dispatcher (solution version)."""

    def __init__(self, prog: str, description: str = ""):
        self._parser = argparse.ArgumentParser(prog=prog, description=description)
        # The selected subcommand name lands in Namespace.command.
        self._subparsers = self._parser.add_subparsers(dest="command")

    def command(self, name: str, help: str = "", setup: Callable = None):
        """Register subcommand *name*; *setup(subparser)* may add its arguments.

        Returns a decorator that makes the decorated function the handler.
        """
        sub = self._subparsers.add_parser(name, help=help)
        if setup:
            setup(sub)

        def decorator(fn: Callable):
            sub.set_defaults(handler=fn)
            fn.parser = sub
            return fn

        return decorator

    def run(self, argv: Optional[List[str]] = None) -> int:
        """Parse *argv*, call the handler and return an exit code (never exits).

        1 = no subcommand (help printed), 2 = no handler, argparse's status
        (2 for usage errors like an unknown subcommand, 0 for --help) when
        parsing stops early; otherwise the handler's result or 0.
        """
        try:
            args = self._parser.parse_args(argv)
        except SystemExit as exc:
            # argparse calls sys.exit() on usage errors and --help; report the
            # status instead of terminating the caller's process.
            if exc.code is None:
                return 0
            return exc.code if isinstance(exc.code, int) else 1
        if not args.command:
            self._parser.print_help()
            return 1
        handler = getattr(args, "handler", None)
        if not handler:
            print(f"Unknown command: {args.command}", file=sys.stderr)
            return 2
        return handler(args) or 0
cli = CLI("mytool", "My deployment tool")


def _add_env_option(p):
    return p.add_argument("--env", default="staging")


def _add_suite_option(p):
    return p.add_argument("--suite", default="all")


@cli.command("deploy", help="Deploy", setup=_add_env_option)
def deploy(args):
    print(f"Deploying to {args.env}")
    return 0


@cli.command("test", help="Run tests", setup=_add_suite_option)
def run_tests(args):
    print(f"Running tests: {args.suite}")
    return 0


deploy_status = cli.run(['deploy', '--env', 'prod'])
print(f"Deploy complete: {deploy_status}")
test_status = cli.run(['test', '--suite', 'unit'])
print(f"Test exit code: {test_status}")
unknown_status = cli.run(['badcmd'])
print(f"Unknown command: {unknown_status}")
Click's architecture: Click uses the same pattern but with class-based Command and Group objects instead of argparse subparsers. @click.group() creates a Group, @group.command() creates and attaches a Command. The dispatch is identical: parse the command name from argv, find the matching Command object, call its invoke() method.
import argparse
from typing import Dict, Callable, List, Optional
class CLI:
    """Multi-subcommand CLI dispatcher (exercise stub — fill in the methods).

    Usage:
        cli = CLI(prog='mytool')

        @cli.command('deploy', help='Deploy the application')
        def deploy(args):
            ...

        cli.run(['deploy', '--env', 'prod'])
    """
    def __init__(self, prog: str, description: str = ''):
        # TODO: build the top-level parser and its subcommand group.
        pass
    def command(self, name: str, help: str = ''):
        # TODO: return a decorator that registers the wrapped function as
        # the handler for subcommand `name`.
        pass
    def run(self, argv: Optional[List[str]] = None) -> int:
        # TODO: parse argv, dispatch to the matching handler and return an
        # integer exit code (see the expected output below).
        pass
Expected Output
Deploying to prod
Deploy complete: 0
Running tests: unit
Test exit code: 0
Unknown command: 2
Hints
Hint 1: Use `parser.add_subparsers(dest="command")` to create a subparser group. Each subcommand gets `subparsers.add_parser(name)` and stores its handler with `set_defaults(handler=fn)`.
Hint 2: After `parse_args()`, call `args.handler(args)` if the handler attribute exists. If no subcommand was given, print help and return 1.
Build a layered configuration system that merges defaults, file config, environment variables, and CLI overrides in priority order.
import os
from typing import Any, Dict, Optional
class Config:
    """Layered key/value configuration.

    Lookup priority, highest first: overrides > environment > file > defaults.
    Every ``load_*`` method returns ``self`` so calls can be chained.
    """

    def __init__(self):
        self._defaults: Dict[str, Any] = {}
        self._file: Dict[str, Any] = {}
        self._env: Dict[str, Any] = {}
        self._overrides: Dict[str, Any] = {}

    def load_defaults(self, defaults: Dict[str, Any]) -> 'Config':
        """Merge *defaults* into the lowest-priority layer."""
        self._defaults.update(defaults)
        return self

    def load_file(self, path: str) -> 'Config':
        """Read ``key = value`` lines from *path* into the file layer.

        Blank lines, ``#`` comments and lines without ``=`` are skipped. Keys are
        stripped and lower-cased; values are stripped and kept as strings. A
        missing file is ignored.
        """
        try:
            with open(path) as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith("#"):
                        continue
                    if "=" in line:
                        key, _, value = line.partition("=")
                        self._file[key.strip().lower()] = value.strip()
        except FileNotFoundError:
            pass  # the config file is optional
        return self

    def load_env(self, prefix: str = "") -> 'Config':
        """Copy environment variables starting with *prefix* into the env layer.

        The prefix is stripped and the rest lower-cased with leading underscores
        removed, e.g. ``MYAPP_DATABASE`` with prefix ``MYAPP_`` -> ``database``.
        """
        prefix_upper = prefix.upper()
        for key, value in os.environ.items():
            if key.startswith(prefix_upper):
                clean_key = key[len(prefix_upper):].lower().lstrip("_")
                self._env[clean_key] = value
        return self

    def load_overrides(self, overrides: Dict[str, Any]) -> 'Config':
        """Apply highest-priority values; ``None`` entries (unset CLI flags) are ignored."""
        self._overrides.update({k: v for k, v in overrides.items() if v is not None})
        return self

    def get(self, key: str, default: Any = None) -> Any:
        """Return *key* from the highest-priority layer that has it, else *default*."""
        for layer in (self._overrides, self._env, self._file, self._defaults):
            if key in layer:
                return layer[key]
        return default

    def all(self) -> Dict[str, Any]:
        """Return every key merged across layers, higher-priority layers winning."""
        merged: Dict[str, Any] = {}
        # Apply the lowest priority first so each later update() overwrites it;
        # iterating in reverse would let defaults clobber overrides.
        for layer in (self._defaults, self._file, self._env, self._overrides):
            merged.update(layer)
        return merged
# Test
import os
import tempfile

# The comment line exercises comment skipping; the file's "database" value is
# shadowed by the environment variable set below.
config_content = "port = 5432\n# comment\ndatabase = filedb\n"
with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as f:
    f.write(config_content)
    tmpfile = f.name
os.environ["MYAPP_DATABASE"] = "mydb"
try:
    config = Config()
    config.load_defaults({"host": "localhost", "port": 5432, "database": "testdb", "debug": False})
    config.load_file(tmpfile)
    config.load_env(prefix="MYAPP_")
    # Read the lower layers before loading overrides; afterwards get('port')
    # resolves to the override, so "From file" would print 9999.
    print(f"From defaults: {config.get('host')}")
    print(f"From file: {config.get('port')}")
    print(f"From env: {config.get('database')}")
    config.load_overrides({"port": 9999})
    print(f"From override: {config.get('port')}")
    print(f"Full config: {config.all()}")
finally:
    # Always remove the temp file and the env var, even if a step fails.
    os.unlink(tmpfile)
    del os.environ["MYAPP_DATABASE"]
Solution
import os
import tempfile
from typing import Any, Dict
class Config:
    """Four-layer configuration store: overrides > env vars > file > defaults.

    The ``load_*`` methods return ``self`` to allow chaining.
    """

    def __init__(self):
        self._defaults: Dict[str, Any] = {}
        self._file: Dict[str, Any] = {}
        self._env: Dict[str, Any] = {}
        self._overrides: Dict[str, Any] = {}

    def load_defaults(self, defaults: Dict[str, Any]) -> 'Config':
        """Merge base values (lowest priority)."""
        self._defaults.update(defaults)
        return self

    def load_file(self, path: str) -> 'Config':
        """Parse ``key = value`` lines (comments/blank/``=``-less lines skipped); missing file is fine."""
        try:
            handle = open(path)
        except FileNotFoundError:
            return self
        with handle:
            for raw_line in handle:
                entry = raw_line.strip()
                if not entry or entry.startswith("#") or "=" not in entry:
                    continue
                name, _, val = entry.partition("=")
                self._file[name.strip().lower()] = val.strip()
        return self

    def load_env(self, prefix: str = "") -> 'Config':
        """Import env vars whose name starts with the upper-cased *prefix*, keyed by the lower-cased remainder."""
        wanted = prefix.upper()
        cut = len(wanted)
        self._env.update({
            name[cut:].lower().lstrip("_"): val
            for name, val in os.environ.items()
            if name.startswith(wanted)
        })
        return self

    def load_overrides(self, overrides: Dict[str, Any]) -> 'Config':
        """Merge highest-priority values, skipping ``None``."""
        self._overrides.update({k: v for k, v in overrides.items() if v is not None})
        return self

    def get(self, key: str, default: Any = None) -> Any:
        """Look *key* up from the highest-priority layer down; fall back to *default*."""
        layers = (self._overrides, self._env, self._file, self._defaults)
        return next((layer[key] for layer in layers if key in layer), default)

    def all(self) -> Dict[str, Any]:
        """Merged view of all layers; later (higher-priority) layers win."""
        return {**self._defaults, **self._file, **self._env, **self._overrides}
with tempfile.NamedTemporaryFile(mode="w", suffix=".conf", delete=False) as f:
    f.write("port = 5432\ndatabase = filedb\n")
    tmpfile = f.name
os.environ["MYAPP_DATABASE"] = "mydb"
try:
    config = (Config()
              .load_defaults({"host": "localhost", "port": 5432, "database": "testdb", "debug": False})
              .load_file(tmpfile)
              .load_env("MYAPP_"))
    # Inspect the lower layers before applying overrides: once the override is
    # loaded, get('port') returns 9999 and the "From file" line would be wrong.
    print(f"From defaults: {config.get('host')}")
    print(f"From file: {config.get('port')}")
    print(f"From env: {config.get('database')}")
    config.load_overrides({"port": 9999})
    print(f"From override: {config.get('port')}")
    print(f"Full config: {config.all()}")
finally:
    # Clean up even if a step above raises.
    os.unlink(tmpfile)
    del os.environ["MYAPP_DATABASE"]
12-factor app config: The 12-Factor App methodology (factor III) recommends storing config in environment variables, not files. Tools like Dynaconf, Pydantic Settings, and python-decouple implement exactly this layered approach, with environment variables winning over config files. The pattern is ubiquitous — Kubernetes, for example, commonly injects ConfigMaps and Secrets into containers as environment variables.
import os
from typing import Any, Dict, Optional
class Config:
    """Layered configuration with priority: CLI args > env vars > config file > defaults.

    Exercise stub — implement:
    - load_defaults(defaults_dict) sets the base
    - load_file(path) reads a simple KEY=VALUE file (like .env)
    - load_env(prefix) reads env vars with a given prefix
    - load_overrides(overrides_dict) applies highest-priority overrides
    - get(key, default) retrieves a value
    - all() returns the merged dict (the expected output prints it as "Full config")
    """
    pass
Expected Output
From defaults: localhost
From file: 5432
From env: mydb
From override: 9999
Full config: {'host': 'localhost', 'port': 9999, 'database': 'mydb', 'debug': False}
Hints
Hint 1: Store four dicts: _defaults, _file, _env, _overrides. In get(), check each in priority order (overrides first, defaults last). Use dict.update() to merge when building the final config.
Hint 2: For env var loading: iterate os.environ, filter keys starting with prefix, strip the prefix, lowercase the key, store the value. Example: APP_HOST -> host.
Implement a terminal progress bar with percentage, item count, and ETA display that overwrites itself in place.
import sys
import time
from typing import Iterator, TypeVar
T = TypeVar('T')


class ProgressBar:
    """Single-line terminal progress bar that redraws itself with a carriage return.

    Renders ``[=====>    ]  50% | 5/10 | ETA: 1.2s`` to stdout.
    """

    def __init__(self, total: int, width: int = 40, prefix: str = ""):
        self.total = total
        self.width = width
        self.prefix = prefix
        self.current = 0
        self._start = time.perf_counter()

    def _render(self) -> str:
        """Build the current bar line (leading carriage return, no newline)."""
        # A zero total is treated as already complete.
        fraction = self.current / self.total if self.total else 1.0
        filled = int(self.width * fraction)
        if filled < self.width:
            bar = "=" * filled + ">" + " " * (self.width - filled - 1)
        else:
            bar = "=" * self.width
        pct = int(fraction * 100)
        elapsed = time.perf_counter() - self._start
        if 0 < self.current < self.total:
            # remaining / (current / elapsed), rearranged so an elapsed time of
            # exactly 0 (coarse clocks) cannot raise ZeroDivisionError.
            eta = elapsed * (self.total - self.current) / self.current
            time_str = f"ETA: {eta:.1f}s"
        elif self.current >= self.total:
            time_str = "Done!"
        else:
            time_str = "..."
        prefix = f"{self.prefix} " if self.prefix else ""
        return f"\r{prefix}[{bar}] {pct:3d}% | {self.current}/{self.total} | {time_str}"

    def update(self, n: int = 1) -> None:
        """Advance by *n* steps (capped at ``total``) and redraw."""
        self.current = min(self.current + n, self.total)
        sys.stdout.write(self._render())
        sys.stdout.flush()

    def finish(self) -> None:
        """Jump to 100%, draw the final line and end it with a newline."""
        self.current = self.total
        sys.stdout.write(self._render() + "\n")
        sys.stdout.flush()

    def wrap(self, iterable) -> Iterator[T]:
        """Yield items from *iterable*, advancing one step per item; finish when exhausted."""
        for item in iterable:
            yield item
            self.update(1)
        self.finish()
# Test: drive the bar through 10 simulated work items
items = list(range(10))
bar = ProgressBar(total=len(items), width=40)
count = 0
for item in bar.wrap(items):
    time.sleep(0.01)  # pretend to do some work
    count = count + 1
print(f"Processed {count} items")
Solution
import sys
import time
from typing import Iterator, TypeVar
T = TypeVar('T')


class ProgressBar:
    """In-place terminal progress bar: ``[====>   ]  50% | 2/4 | ETA: 0.3s``."""

    def __init__(self, total: int, width: int = 40, prefix: str = ""):
        self.total = total
        self.width = width
        self.prefix = prefix
        self.current = 0
        self._start = time.perf_counter()

    def _render(self) -> str:
        fraction = self.current / self.total if self.total else 1.0
        filled = int(self.width * fraction)
        # The ">" head is dropped once the bar is full so the width stays fixed.
        bar = "=" * filled + (">" if filled < self.width else "") + " " * max(0, self.width - filled - 1)
        pct = int(fraction * 100)
        elapsed = time.perf_counter() - self._start
        if 0 < self.current < self.total:
            # remaining / rate with rate = current / elapsed, rearranged so a
            # zero elapsed time cannot raise ZeroDivisionError.
            eta = elapsed * (self.total - self.current) / self.current
            time_str = f"ETA: {eta:.1f}s"
        else:
            time_str = "Done!" if self.current >= self.total else "..."
        prefix = f"{self.prefix} " if self.prefix else ""
        return f"\r{prefix}[{bar}] {pct:3d}% | {self.current}/{self.total} | {time_str}"

    def update(self, n: int = 1) -> None:
        """Advance by *n* (capped at total) and redraw."""
        self.current = min(self.current + n, self.total)
        sys.stdout.write(self._render())
        sys.stdout.flush()

    def finish(self) -> None:
        """Mark complete and terminate the line with a newline."""
        self.current = self.total
        sys.stdout.write(self._render() + "\n")
        sys.stdout.flush()

    def wrap(self, iterable) -> Iterator[T]:
        """Yield each item, advancing the bar after it; call finish() at the end."""
        for item in iterable:
            yield item
            self.update(1)
        self.finish()
items = [n for n in range(10)]
bar = ProgressBar(total=len(items), width=40)
count = 0
for item in bar.wrap(items):
    count += 1
    time.sleep(0.01)  # simulated work per item
print("Processed {} items".format(count))
tqdm's approach: tqdm wraps this same \r overwrite technique and adds rate smoothing (exponential moving average of the rate) for more stable ETA estimates. It also handles edge cases like redirected stdout (no \r in piped output), multiprocessing (thread-safe lock around write), and Jupyter notebooks (uses widget output instead of \r).
import sys
import time
from typing import Optional, Iterator, TypeVar
T = TypeVar('T')
class ProgressBar:
    """Terminal progress bar with ETA estimation (exercise stub).

    - update(n) advances by n steps
    - finish() marks complete and prints final line
    - wrap(iterable) wraps an iterable, auto-advancing per item
    - Shows: [=====>    ] 50% | 5/10 | ETA: 5.0s
    """
    def __init__(self, total: int, width: int = 40, prefix: str = ''):
        # TODO: store the settings and record the start time for ETA.
        pass
    def update(self, n: int = 1) -> None:
        # TODO: advance and redraw the line in place (no newline).
        pass
    def finish(self) -> None:
        # TODO: jump to 100% and end the line with a newline.
        pass
    def wrap(self, iterable) -> Iterator[T]:
        # TODO: yield each item, advancing once per item, then finish().
        pass
Expected Output
[========================================] 100% | 10/10 | Done!
Processed 10 items
Hints
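Hint 1: Calculate filled = int(width * current / total). While filled < width, the bar string is "=" * filled + ">" + " " * (width - filled - 1); once filled == width, use "=" * width (otherwise the ">" makes the full bar one character too wide). Use \r to overwrite the current line without a newline.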
Hint 2: For ETA: track start_time with time.perf_counter(). elapsed = now - start_time. rate = current / elapsed. eta = (total - current) / rate.
Build an interactive prompt function with type coercion, validation, choices restriction, and secret input support.
from typing import Callable, TypeVar, Optional, List
import getpass
T = TypeVar('T')


def prompt(
    message: str,
    type: Callable = str,
    default: Optional[T] = None,
    validator: Optional[Callable] = None,
    choices: Optional[List] = None,
    secret: bool = False,
    _test_inputs: Optional[List[str]] = None,
) -> T:
    """Interactive prompt that loops until the answer passes every check.

    Args:
        message: Prompt text; ``[default]`` and ``(a/b/c)`` choices are appended.
        type: Callable that coerces the stripped input (named ``type`` for API
            compatibility; it shadows the builtin inside this function). A
            ValueError/TypeError from it prints the error and re-prompts.
        default: Returned (after the checks) when the input is blank.
        validator: Called with the coerced value; a truthy return is printed
            as an error and the user is asked again.
        choices: If given, the value must be one of these.
        secret: Read with ``getpass.getpass`` so the input is not echoed.
        _test_inputs: Scripted answers consumed before falling back to real
            input; used for automated testing.

    Returns:
        The first value that passes coercion, the choices check and the validator.
    """
    scripted = iter(_test_inputs or [])
    display = message
    if default is not None:
        display = f"{message} [{default}]"
    if choices:
        display = f"{display} ({'/'.join(str(c) for c in choices)})"
    display += ": "
    while True:
        try:
            raw = next(scripted)
        except StopIteration:
            raw = getpass.getpass(display) if secret else input(display)
        else:
            # Echo scripted answers like a terminal would, but never a secret:
            # getpass does not echo, and test logs must not leak passwords.
            print(f"{display}{'' if secret else raw}")
        if not raw.strip() and default is not None:
            value = default
        else:
            try:
                value = type(raw.strip())
            except (ValueError, TypeError) as e:
                print(f" Error: {e}")
                continue
        if choices and value not in choices:
            print(f" Error: must be one of {choices}")
            continue
        if validator:
            error = validator(value)
            if error:
                print(f" Error: {error}")
                continue
        return value
# Test with predefined (scripted) inputs instead of a real terminal
def _check_port(p):
    """Return an error string for ports outside 1024-65535, else None."""
    return None if 1024 <= p <= 65535 else "Port must be 1024-65535"


name = prompt("Enter name", default="Alice", _test_inputs=[""])
print(f"Got name: {name}")
port = prompt("Enter port", type=int, default=8080, validator=_check_port, _test_inputs=["9000"])
print(f"Got port: {port}")
env = prompt("Choose environment", choices=["dev", "staging", "prod"], _test_inputs=["prod"])
print(f"Chose: {env}")
Solution
from typing import Callable, TypeVar, Optional, List
import getpass
T = TypeVar('T')


def prompt(
    message: str,
    type: Callable = str,
    default: Optional[T] = None,
    validator: Optional[Callable] = None,
    choices: Optional[List] = None,
    secret: bool = False,
    _test_inputs: Optional[List[str]] = None,
) -> T:
    """Ask until the answer coerces with *type*, is in *choices* and passes *validator*.

    Blank input yields *default* (when one is given). ``secret=True`` reads via
    getpass. ``_test_inputs`` supplies scripted answers for tests before
    falling back to the terminal.
    """
    scripted = iter(_test_inputs or [])
    display = message
    if default is not None:
        display = f"{message} [{default}]"
    if choices:
        display = f"{display} ({'/'.join(str(c) for c in choices)})"
    display += ": "
    while True:
        try:
            raw = next(scripted)
        except StopIteration:
            raw = getpass.getpass(display) if secret else input(display)
        else:
            # Mirror getpass: never echo a secret, even for scripted input.
            print(f"{display}{'' if secret else raw}")
        if not raw.strip() and default is not None:
            value = default
        else:
            try:
                value = type(raw.strip())
            except (ValueError, TypeError) as e:
                print(f" Error: {e}")
                continue
        if choices and value not in choices:
            print(f" Error: must be one of {choices}")
            continue
        if validator:
            error = validator(value)
            if error:
                print(f" Error: {error}")
                continue
        return value
name = prompt("Enter name", default="Alice", _test_inputs=[""])
print(f"Got name: {name}")


def _port_in_range(p):
    if 1024 <= p <= 65535:
        return None
    return "Port must be 1024-65535"


port = prompt(
    "Enter port",
    type=int,
    default=8080,
    validator=_port_in_range,
    _test_inputs=["9000"],
)
print(f"Got port: {port}")
env = prompt(
    "Choose environment",
    choices=["dev", "staging", "prod"],
    _test_inputs=["prod"],
)
print(f"Chose: {env}")
Click's prompt mechanism: Click's @click.option("--name", prompt=True) generates a very similar loop. The _test_inputs parameter is the testing escape hatch here; Click instead provides CliRunner, whose invoke(..., input=...) feeds a fake stdin to the command rather than injecting a list. Always design interactive functions to be testable without a real terminal.
from typing import Callable, TypeVar, Optional, Any
T = TypeVar('T')
def prompt(
    message: str,
    type: Callable[[str], T] = str,
    default: Optional[T] = None,
    validator: Optional[Callable[[T], Optional[str]]] = None,
    choices: Optional[list] = None,
    secret: bool = False,
) -> T:
    """Interactive prompt with type coercion and validation (exercise stub).

    - Displays message and optional default
    - Coerces input to the given type
    - Calls validator(value) -> error_str | None
    - Restricts to allowed choices if provided
    - Uses getpass for secret=True
    - Loops until valid input
    """
    pass
Expected Output
Enter name [Alice]: 
Got name: Alice
Enter port [8080]: 9000
Got port: 9000
Choose environment (dev/staging/prod): prod
Chose: prod
Hints
Hint 1: Show default in the prompt: f"{message} [{default}]: " if default is not None. If user presses Enter with no input, use the default.
Hint 2: Loop with while True: ask for input, attempt type coercion, run validator, check choices — break only when all pass. For secret=True, use getpass.getpass() instead of input().
Hard
Build a plugin registry that can dynamically load plugin classes from individual .py files or by scanning a directory.
import importlib.util
import os
import inspect
import sys
import tempfile
from typing import Dict, Any, Type, List, Optional
class PluginBase:
    """Base class every plugin must inherit from.

    Subclasses set ``name`` (the registry key; file loading skips classes whose
    name is empty) and override ``execute``.
    """
    name: str = ""
    version: str = "1.0.0"

    def execute(self, context: Dict[str, Any]) -> Any:
        raise NotImplementedError


class PluginRegistry:
    """Registers PluginBase subclasses by name and loads them from .py files."""

    def __init__(self):
        self._plugins: Dict[str, Type[PluginBase]] = {}

    def register(self, plugin_cls: Type[PluginBase]) -> None:
        """Add *plugin_cls* under its ``name``; raises TypeError for non-plugins."""
        if not issubclass(plugin_cls, PluginBase):
            raise TypeError(f"{plugin_cls} must inherit from PluginBase")
        self._plugins[plugin_cls.name] = plugin_cls

    def load_from_file(self, path: str) -> List[str]:
        """Import the module at *path* and register every named PluginBase subclass in it.

        Returns the names registered. Raises ImportError if *path* is not an
        importable source file; any exception raised while executing the module
        propagates.
        """
        module_name = os.path.splitext(os.path.basename(path))[0]
        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load plugin module from {path!r}")
        module = importlib.util.module_from_spec(spec)
        # Register before executing (as in the importlib recipe) so code that
        # looks the module up by name (dataclasses, pickle, ...) works; undo it
        # if the module fails to execute so no half-initialised module lingers.
        # NOTE(review): a file named like an existing module would replace it.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)
            raise
        loaded = []
        for attr_name in dir(module):
            obj = getattr(module, attr_name)
            if (isinstance(obj, type) and
                    issubclass(obj, PluginBase) and
                    obj is not PluginBase and
                    obj.name):
                self.register(obj)
                loaded.append(obj.name)
        return loaded

    def discover(self, directory: str) -> List[str]:
        """Load every ``plugin_*.py`` file in *directory*; return the registered names."""
        discovered = []
        # sorted() makes load order (and the returned list) independent of the
        # filesystem's directory order.
        for filename in sorted(os.listdir(directory)):
            if filename.startswith("plugin_") and filename.endswith(".py"):
                path = os.path.join(directory, filename)
                discovered.extend(self.load_from_file(path))
        return discovered

    def get(self, name: str) -> Optional[Type[PluginBase]]:
        """Return the plugin class registered as *name*, or None."""
        return self._plugins.get(name)

    def list_plugins(self) -> List[str]:
        """Return registered plugin names in sorted order."""
        return sorted(self._plugins.keys())
# Test
class GreetPlugin(PluginBase):
    name = "greet"

    def execute(self, context):
        return f"Hello, {context.get('target', 'World')}!"


registry = PluginRegistry()
registry.register(GreetPlugin)
print(f"Registered: {registry.list_plugins()[0]}")

# Create a temp plugin file; it imports PluginBase from this module, so the
# placeholder is replaced with the running module's name before writing.
plugin_code = '''
from __main__ import PluginBase

class HelloPlugin(PluginBase):
    name = "hello_plugin"

    def execute(self, context):
        return "Plugin says hi"
'''
# Use the platform temp dir (a hard-coded /tmp does not exist on Windows).
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
    f.write(plugin_code.replace("__main__", __name__))
    tmp_plugin = f.name
try:
    loaded = registry.load_from_file(tmp_plugin)
finally:
    os.unlink(tmp_plugin)  # the module is already imported; remove the file
print(f"Loaded from file: {loaded[0] if loaded else 'none'}")
print(f"Discovered plugins: {registry.list_plugins()}")
greet_cls = registry.get("greet")
if greet_cls:
    print(f"greet result: {greet_cls().execute({'target': 'World'})}")
hello_cls = registry.get("hello_plugin")
if hello_cls:
    print(f"hello_plugin result: {hello_cls().execute({})}")
Solution
import importlib.util
import os
import sys
import tempfile
from typing import Dict, Any, Type, List, Optional
class PluginBase:
    """Base class all plugins inherit from; set ``name`` and implement ``execute``."""
    name: str = ""
    version: str = "1.0.0"

    def execute(self, context: Dict[str, Any]) -> Any:
        raise NotImplementedError


class PluginRegistry:
    """Name -> plugin-class registry with file loading and directory discovery."""

    def __init__(self):
        self._plugins: Dict[str, Type[PluginBase]] = {}

    def register(self, plugin_cls: Type[PluginBase]) -> None:
        """Register *plugin_cls* under its ``name``; TypeError if not a PluginBase subclass."""
        if not issubclass(plugin_cls, PluginBase):
            raise TypeError(f"{plugin_cls} must inherit PluginBase")
        self._plugins[plugin_cls.name] = plugin_cls

    def load_from_file(self, path: str) -> List[str]:
        """Import *path* as a module and register its named PluginBase subclasses.

        Returns the registered names. Raises ImportError when *path* has no
        source loader; errors raised by the module itself propagate.
        """
        module_name = os.path.splitext(os.path.basename(path))[0]
        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load plugin module from {path!r}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Don't leave a half-initialised module importable after a failure.
            sys.modules.pop(module_name, None)
            raise
        loaded = []
        for attr_name in dir(module):
            obj = getattr(module, attr_name)
            if isinstance(obj, type) and issubclass(obj, PluginBase) and obj is not PluginBase and obj.name:
                self.register(obj)
                loaded.append(obj.name)
        return loaded

    def discover(self, directory: str) -> List[str]:
        """Load ``plugin_*.py`` files from *directory* in sorted (deterministic) order."""
        discovered = []
        for filename in sorted(os.listdir(directory)):
            if filename.startswith("plugin_") and filename.endswith(".py"):
                discovered.extend(self.load_from_file(os.path.join(directory, filename)))
        return discovered

    def get(self, name: str) -> Optional[Type[PluginBase]]:
        return self._plugins.get(name)

    def list_plugins(self) -> List[str]:
        return sorted(self._plugins.keys())
class GreetPlugin(PluginBase):
    name = "greet"

    def execute(self, context):
        return f"Hello, {context.get('target', 'World')}!"


registry = PluginRegistry()
registry.register(GreetPlugin)
print(f"Registered: {registry.list_plugins()[0]}")

# The plugin file imports PluginBase from this very module (looked up by
# __name__), so it is a real subclass and load_from_file() registers it through
# the normal register() path.
plugin_code = f"""
from {__name__} import PluginBase

class HelloPlugin(PluginBase):
    name = "hello_plugin"
    version = "1.0.0"

    def execute(self, context):
        return "Plugin says hi"
"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
    f.write(plugin_code)
    tmp = f.name
try:
    loaded = registry.load_from_file(tmp)
finally:
    os.unlink(tmp)
print(f"Loaded from file: {loaded[0] if loaded else 'none'}")
print(f"Discovered plugins: {registry.list_plugins()}")
g = registry.get("greet")
print(f"greet result: {g().execute({'target': 'World'})}")
hello = registry.get("hello_plugin")
if hello:
    print(f"hello_plugin result: {hello().execute({})}")
Entry points (production approach): Python's importlib.metadata.entry_points() is the production plugin discovery mechanism. Plugin packages declare their plugins in pyproject.toml under [project.entry-points."myapp.plugins"]. The host app calls entry_points(group="myapp.plugins") and iterates the discovered entry points without needing to scan the filesystem. pip registers them during package installation.
import importlib
import importlib.util
import os
from typing import Dict, Any, Type, List, Optional
class PluginBase:
    """Base class all plugins must inherit from.

    ``name`` identifies the plugin in the registry; ``version`` defaults to
    1.0.0. Subclasses must override :meth:`execute`.
    """

    name: str = ""
    version: str = "1.0.0"

    def execute(self, context: Dict[str, Any]) -> Any:
        """Run the plugin with *context*; the base class only signals "not implemented"."""
        raise NotImplementedError()
class PluginRegistry:
    """Discovers, loads, and manages CLI plugins (exercise stub).

    - register(plugin_cls) manually registers a plugin class
    - load_from_file(path) dynamically loads a .py file as a plugin module
    - discover(directory) scans a directory for plugin_*.py files
    - get(name) returns a plugin class by name
    - list_plugins() returns all registered plugin names
    """
    pass
Expected Output
Registered: greet
Loaded from file: hello_plugin
Discovered plugins: ['greet', 'hello_plugin']
greet result: Hello, World!
hello_plugin result: Plugin says hi
Hints
Hint 1: Use `importlib.util.spec_from_file_location(name, path)` and `importlib.util.module_from_spec(spec)` to load a .py file as a module. Then inspect the module for PluginBase subclasses.
Hint 2: To find PluginBase subclasses in a loaded module: iterate `dir(module)`, get each attribute with `getattr`, check `isinstance(obj, type) and issubclass(obj, PluginBase) and obj is not PluginBase`.
Implement the Command pattern with a full undo/redo history for a text editor buffer.
from typing import List, Optional, Any
from abc import ABC, abstractmethod
class Command(ABC):
    """Abstract undoable action managed by CommandHistory."""

    @abstractmethod
    def execute(self) -> Any:
        """Apply the action; the result is returned to the caller."""

    @abstractmethod
    def undo(self) -> Any:
        """Revert exactly what execute() did."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Short label listed by CommandHistory.history()."""


class CommandHistory:
    """Undo/redo stacks of executed commands."""

    def __init__(self):
        self._history: List[Command] = []
        self._redo_stack: List[Command] = []

    def execute(self, command: Command) -> Any:
        """Run *command*, push it on the undo stack and invalidate redo."""
        result = command.execute()
        self._history.append(command)
        self._redo_stack.clear()
        return result

    def undo(self) -> Optional[Any]:
        """Undo the most recent command; return its result, or None if nothing to undo."""
        if not self._history:
            return None
        command = self._history[-1]
        result = command.undo()
        # Move the command only after undo() succeeded; if it raises, the
        # command stays in history instead of being lost from both stacks.
        self._redo_stack.append(self._history.pop())
        return result

    def redo(self) -> Optional[Any]:
        """Re-execute the most recently undone command; None if nothing to redo."""
        if not self._redo_stack:
            return None
        command = self._redo_stack[-1]
        result = command.execute()
        self._history.append(self._redo_stack.pop())
        return result

    def history(self) -> List[str]:
        """Descriptions of the undoable commands, oldest first."""
        return [cmd.description for cmd in self._history]

    @property
    def undo_stack_size(self) -> int:
        return len(self._history)

    @property
    def redo_stack_size(self) -> int:
        return len(self._redo_stack)


# Concrete commands for a text buffer
class TextBuffer:
    """Mutable holder for the edited text."""

    def __init__(self):
        self.text = ""


class InsertCommand(Command):
    """Insert the string ``char`` (one or more characters) at ``position``."""

    def __init__(self, buffer: TextBuffer, char: str, position: int):
        self._buffer = buffer
        self._char = char
        self._position = position

    def execute(self) -> str:
        text = self._buffer.text
        self._buffer.text = text[:self._position] + self._char + text[self._position:]
        return self._buffer.text

    def undo(self) -> str:
        # Remove len(char) characters so multi-character inserts are fully undone.
        text = self._buffer.text
        end = self._position + len(self._char)
        self._buffer.text = text[:self._position] + text[end:]
        return self._buffer.text

    @property
    def description(self) -> str:
        return f"Insert: {self._char!r}"
# Demo: type "Hello World" one character at a time, undo twice, redo once.
buf = TextBuffer()
hist = CommandHistory()
for char in "Hello World":
    # Each insertion appends at the current end of the buffer.
    hist.execute(InsertCommand(buf, char, len(buf.text)))
print("After inserts: " + buf.text)
for _ in range(2):
    hist.undo()
print("After undo x2: " + buf.text)
hist.redo()
print(f"After redo: {buf.text!r}")
print(f"History length: {hist.undo_stack_size}")
Solution
from typing import List, Optional, Any
from abc import ABC, abstractmethod
class Command(ABC):
    """An undoable action: ``execute`` applies it, ``undo`` reverts it."""

    @abstractmethod
    def execute(self) -> Any:
        """Apply the action; the return value is implementation-defined."""

    @abstractmethod
    def undo(self) -> Any:
        """Revert the most recent ``execute``."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Short human-readable label used by ``CommandHistory.history``."""


class CommandHistory:
    """Execute commands and keep undo/redo stacks for them.

    Executing a new command clears the redo stack: commands undone before it
    belong to an abandoned branch of history.
    """

    def __init__(self):
        self._history: List[Command] = []
        self._redo_stack: List[Command] = []

    def execute(self, command: Command) -> Any:
        """Run ``command``, record it, and discard pending redo history."""
        result = command.execute()
        self._history.append(command)
        self._redo_stack.clear()
        return result

    def undo(self) -> Optional[Any]:
        """Undo the most recent command; None when there is nothing to undo."""
        if not self._history:
            return None
        cmd = self._history.pop()
        result = cmd.undo()
        self._redo_stack.append(cmd)
        return result

    def redo(self) -> Optional[Any]:
        """Re-execute the most recently undone command; None when none is pending."""
        if not self._redo_stack:
            return None
        cmd = self._redo_stack.pop()
        result = cmd.execute()
        self._history.append(cmd)
        return result

    def history(self) -> List[str]:
        """Descriptions of executed (not undone) commands, oldest first."""
        return [cmd.description for cmd in self._history]

    @property
    def undo_stack_size(self) -> int:
        """Number of commands that can currently be undone."""
        return len(self._history)

    @property
    def redo_stack_size(self) -> int:
        """Number of commands that can currently be redone."""
        return len(self._redo_stack)


class TextBuffer:
    """Mutable holder for the text being edited."""

    def __init__(self):
        self.text = ""


class InsertCommand(Command):
    """Insert ``char`` (any string) at ``pos`` in ``buf``.

    ``undo`` assumes the buffer is in the state ``execute`` left it in (true
    when all edits go through one CommandHistory, which undoes LIFO).
    """

    def __init__(self, buf, char, pos):
        self._buf, self._char, self._pos = buf, char, pos

    def execute(self):
        """Splice the text in and return the new buffer contents."""
        self._buf.text = self._buf.text[:self._pos] + self._char + self._buf.text[self._pos:]
        return self._buf.text

    def undo(self):
        """Remove exactly the inserted text and return the new buffer contents."""
        # len(self._char), not 1: multi-character inserts must be fully removed.
        end = self._pos + len(self._char)
        self._buf.text = self._buf.text[:self._pos] + self._buf.text[end:]
        return self._buf.text

    @property
    def description(self):
        """Label such as ``Insert: 'H'``."""
        return f"Insert: {self._char!r}"
# Demo: build "Hello World" character by character, then exercise undo/redo.
buf = TextBuffer()
hist = CommandHistory()
for char in "Hello World":
    hist.execute(InsertCommand(buf, char, len(buf.text)))
print("After inserts: " + buf.text)
for _ in range(2):
    hist.undo()
print("After undo x2: " + buf.text)
hist.redo()
print(f"After redo: {buf.text!r}")
print(f"History length: {hist.undo_stack_size}")
Real-world uses: VS Code, vim (u/Ctrl-R), Photoshop, and databases (transaction logs) all use variants of this pattern. Database WAL (Write-Ahead Logging) is the Command pattern at scale: each operation is serialized as a command, allowing rollback (undo) and crash recovery (redo from log). Git commits are essentially named commands with undo via git revert.
from typing import List, Optional, Any
from abc import ABC, abstractmethod
class Command(ABC):
    """Interface that every undoable command implements."""

    @abstractmethod
    def execute(self) -> Any:
        """Perform the command."""

    @abstractmethod
    def undo(self) -> Any:
        """Reverse a previous execute()."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable label for history listings."""
# Exercise stub: replace the placeholder `pass` below with the implementation.
class CommandHistory:
"""Manages command execution with full undo/redo support.
- execute(command) runs and records a command
- undo() reverses the last command
- redo() re-applies the last undone command
- history() returns list of executed command descriptions
- undo_stack_size, redo_stack_size properties
- executing a new command clears the redo stack (undone commands can no longer be redone)
- in the reference solution, undo()/redo() return None when their stack is empty
"""
passExpected Output
After inserts: Hello World\nAfter undo x2: Hello Wor\nAfter redo: 'Hello Worl'\nHistory length: 10
Hints
Hint 1: Maintain two stacks: `_history` (executed commands) and `_redo_stack`. On execute: push to _history, clear _redo_stack. On undo: pop from _history, call command.undo(), push to _redo_stack. On redo: pop from _redo_stack, call command.execute(), push back to _history.
Hint 2: Clearing the redo stack on new execute is critical — once you execute a new command after undoing, the undone commands cannot be redone (they are a different branch of history).
Build a terminal table renderer with auto-width detection, per-column alignment, and long-value truncation.
from typing import List, Dict, Any, Optional
from enum import Enum
class Align(Enum):
    """Horizontal alignment of a table column."""

    LEFT = "left"
    RIGHT = "right"
    CENTER = "center"


class TableRenderer:
    """Render rows of dicts as a fixed-width ASCII table.

    Each column is as wide as its header or its widest ``str()`` cell,
    capped at ``max_col_width``. Longer cells are truncated.

    Args:
        max_col_width: Upper bound on any column's content width.
    """

    def __init__(self, max_col_width: int = 30):
        self.max_col_width = max_col_width

    def _truncate(self, value: str, width: int) -> str:
        """Shorten ``value`` to at most ``width`` characters.

        Adds a trailing ``"..."`` when there is room for it. Widths below 3
        get a hard cut instead: ``value[:width - 3]`` would have a negative
        stop there and return a cell wider than ``width``.
        """
        if len(value) <= width:
            return value
        if width < 3:
            return value[:max(width, 0)]
        return value[:width - 3] + "..."

    def _align(self, text: str, width: int, alignment: Align) -> str:
        """Pad ``text`` to ``width`` according to ``alignment`` (LEFT by default)."""
        if alignment == Align.RIGHT:
            return text.rjust(width)
        elif alignment == Align.CENTER:
            return text.center(width)
        else:
            return text.ljust(width)

    def render(
        self,
        data: List[Dict[str, Any]],
        columns: Optional[List[str]] = None,
        alignment: Optional[Dict[str, Align]] = None,
    ) -> str:
        """Return the table as a newline-joined string.

        Args:
            data: Rows; a missing key renders as an empty cell.
            columns: Columns to show, in order; None means all keys of the first row.
            alignment: Per-column alignment; unlisted columns align LEFT.

        Returns:
            The rendered table, or ``"(empty)"`` when ``data`` is empty.
        """
        if not data:
            return "(empty)"
        if columns is None:
            columns = list(data[0].keys())
        alignment = alignment or {}

        # Column width = widest of header and cells, capped at max_col_width.
        widths = {}
        for col in columns:
            col_width = len(col)
            for row in data:
                col_width = max(col_width, len(str(row.get(col, ""))))
            widths[col] = min(col_width, self.max_col_width)

        def render_row(row_data):
            """Format one row (header or data) as ``| a | b |``."""
            cells = []
            for col in columns:
                value = self._truncate(str(row_data.get(col, "")), widths[col])
                cells.append(self._align(value, widths[col], alignment.get(col, Align.LEFT)))
            return "| " + " | ".join(cells) + " |"

        # "+ 2" accounts for the single space padding on each side of a cell.
        separator = "+" + "+".join("-" * (widths[col] + 2) for col in columns) + "+"
        header = render_row({col: col for col in columns})
        rows = [separator, header, separator]
        for row in data:
            rows.append(render_row(row))
        rows.append(separator)
        return "\n".join(rows)
# Demo: widths are capped at 15 characters; the "active" key is not shown
# because it is not listed in columns.
data = [
    {"name": "Alice Smith", "role": "Engineer", "score": 95, "active": True},
    {"name": "Bob Jones", "role": "Designer", "score": 87, "active": False},
    {"name": "Carol Williams", "role": "Product Manager", "score": 92, "active": True},
    {"name": "Dave", "role": "QA", "score": 78, "active": True},
]
renderer = TableRenderer(max_col_width=15)
print(
    renderer.render(
        data,
        columns=["name", "role", "score"],
        alignment={"score": Align.RIGHT, "name": Align.LEFT, "role": Align.LEFT},
    )
)
Solution
from typing import List, Dict, Any, Optional
from enum import Enum
class Align(Enum):
    """Horizontal alignment of a table column."""

    LEFT = "left"
    RIGHT = "right"
    CENTER = "center"


class TableRenderer:
    """Render rows of dicts as a fixed-width ASCII table (compact version).

    Args:
        max_col_width: Upper bound on any column's content width.
    """

    def __init__(self, max_col_width: int = 30):
        self.max_col_width = max_col_width

    def _truncate(self, value: str, width: int) -> str:
        """Cut ``value`` to at most ``width`` chars, ending in "..." when it fits.

        Widths below 3 get a hard cut: ``value[:width-3]`` would have a
        negative stop there and return a string longer than ``width``.
        """
        if len(value) <= width:
            return value
        return value[:max(width, 0)] if width < 3 else value[:width - 3] + "..."

    def _align(self, text: str, width: int, alignment: Align) -> str:
        """Pad ``text`` to ``width`` according to ``alignment``."""
        if alignment == Align.RIGHT:
            return text.rjust(width)
        if alignment == Align.CENTER:
            return text.center(width)
        return text.ljust(width)

    def render(self, data, columns=None, alignment=None) -> str:
        """Return the table as a string, or ``"(empty)"`` for empty ``data``.

        ``columns`` defaults to the keys of the first row; note that an empty
        list also falls back to that default. Columns missing from
        ``alignment`` align LEFT.
        """
        if not data:
            return "(empty)"
        columns = columns or list(data[0].keys())
        alignment = alignment or {}
        widths = {
            col: min(max(len(col), max(len(str(r.get(col, ""))) for r in data)), self.max_col_width)
            for col in columns
        }
        # "+ 2": one space of padding on each side of every cell.
        sep = "+" + "+".join("-" * (widths[c] + 2) for c in columns) + "+"

        def row_str(row):
            """Format one row (header or data) as ``| a | b |``."""
            cells = [self._align(self._truncate(str(row.get(c, "")), widths[c]), widths[c], alignment.get(c, Align.LEFT))
                     for c in columns]
            return "| " + " | ".join(cells) + " |"

        lines = [sep, row_str({c: c for c in columns}), sep]
        lines += [row_str(r) for r in data]
        lines.append(sep)
        return "\n".join(lines)
# Demo: widths are capped at 15 characters; only "score" is right-aligned,
# the other columns fall back to LEFT.
data = [
    {"name": "Alice Smith", "role": "Engineer", "score": 95},
    {"name": "Bob Jones", "role": "Designer", "score": 87},
    {"name": "Carol Williams", "role": "Product Manager", "score": 92},
]
renderer = TableRenderer(max_col_width=15)
print(
    renderer.render(
        data,
        columns=["name", "role", "score"],
        alignment={"score": Align.RIGHT},
    )
)
Rich library internals: The rich library's Table class is a production version of this renderer. It adds: Unicode box-drawing characters (┌─┬─┐), colored cells, row highlighting, column padding, and markdown-style formatting. Understanding this bare implementation helps you reason about why rich.Table performs the way it does and when to add no_wrap=True for performance.
from typing import List, Dict, Any, Optional
from enum import Enum
class Align(Enum):
    """Column alignment choices: left, right or centered."""

    LEFT = "left"
    RIGHT = "right"
    CENTER = "center"
# Exercise stub: replace each placeholder `pass` with the implementation.
class TableRenderer:
"""Renders data as an aligned terminal table.
Features:
- Auto-detects column widths from data
- Supports left/right/center alignment per column (columns not listed in alignment default to LEFT)
- Header separator line (the reference solution always draws it; no parameter toggles it)
- Truncates long values with ellipsis
- Handles missing values with empty string
- Column widths are capped at max_col_width
- Empty data renders as "(empty)" in the reference solution
"""
# max_col_width: upper bound on the content width of every column.
def __init__(self, max_col_width: int = 30):
pass
# columns defaults to the keys of the first row; alignment maps column -> Align.
def render(
self,
data: List[Dict[str, Any]],
columns: Optional[List[str]] = None,
alignment: Optional[Dict[str, Align]] = None,
) -> str:
passExpected Output
See the full formatted table in the solution.
Hints
Hint 1: Column width = max(len(header), max(len(str(row[col])) for all rows)). Cap at max_col_width and add ellipsis if truncated.
Hint 2: For alignment: left = str.ljust(width), right = str.rjust(width), center = str.center(width). Build separator line with "-" * width between header and data rows.
Build a shell pipeline executor that supports multi-command pipes, environment injection, timeout, and streaming output.
import subprocess
import os
import sys
from typing import List, Optional, Tuple, Callable
from dataclasses import dataclass
@dataclass
class CommandResult:
    """Exit status and captured output of a run.

    For a pipeline, all three fields describe the LAST command only.
    """

    returncode: int
    stdout: str
    stderr: str

    @property
    def ok(self) -> bool:
        """True when ``returncode`` is 0."""
        return self.returncode == 0


class ShellPipeline:
    """Execute shell commands, optionally chained like ``cmd1 | cmd2 | cmd3``.

    Every command runs with ``shell=True``, so command strings are interpreted
    by the system shell. Only pass trusted input.

    Args:
        env: Extra environment variables layered over ``os.environ``.
        cwd: Working directory for every command (None = inherit).
    """

    def __init__(self, env: Optional[dict] = None, cwd: Optional[str] = None):
        self._env = env
        self._cwd = cwd

    @staticmethod
    def _kill_and_reap(processes: List[subprocess.Popen]) -> None:
        """Kill ``processes``, wait for each, and close its parent-side pipes.

        Uses ``wait()`` rather than ``communicate()``. A shell grandchild that
        inherited a pipe could otherwise keep us blocked after the kill.
        """
        for proc in processes:
            proc.kill()
        for proc in processes:
            proc.wait()
            for stream in (proc.stdout, proc.stderr):
                if stream is not None:
                    stream.close()

    def run(
        self,
        *commands: str,
        timeout: Optional[float] = None,
        stream_callback: Optional[Callable[[str], None]] = None,
    ) -> CommandResult:
        """Run one or more piped commands.

        Args:
            *commands: Shell command strings; stdout of each feeds stdin of the next.
            timeout: Seconds to wait for the last command (None = no limit).
            stream_callback: Called once per stdout line of the last command.
                Lines are delivered after the pipeline finishes, not live.

        Returns:
            CommandResult for the last command, with stdout/stderr stripped.
            Stderr of earlier stages is discarded. With no commands, returns
            returncode 1 and an explanatory stderr.

        Raises:
            subprocess.TimeoutExpired: The last command exceeded ``timeout``.
                All stages are killed and reaped before this propagates.
        """
        if not commands:
            return CommandResult(returncode=1, stdout="", stderr="No commands provided")

        env = {**os.environ, **(self._env or {})}
        processes: List[subprocess.Popen] = []
        prev_stdout = None
        try:
            for i, cmd in enumerate(commands):
                is_last = i == len(commands) - 1
                proc = subprocess.Popen(
                    cmd,
                    shell=True,
                    stdin=prev_stdout,
                    stdout=subprocess.PIPE,
                    # Only the last stage's stderr is reported.
                    stderr=subprocess.PIPE if is_last else subprocess.DEVNULL,
                    env=env,
                    cwd=self._cwd,
                    text=True,
                )
                # The child now owns the read end. Drop the parent's copy so the
                # upstream stage gets SIGPIPE if this stage exits early.
                if prev_stdout:
                    prev_stdout.close()
                prev_stdout = proc.stdout
                processes.append(proc)
        except BaseException:
            # Spawning a later stage failed: don't leave earlier stages running.
            self._kill_and_reap(processes)
            raise

        last = processes[-1]
        try:
            stdout, stderr = last.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            self._kill_and_reap(processes)
            raise

        # Reap the earlier stages so no zombie processes are left behind.
        for proc in processes[:-1]:
            proc.wait()

        if stream_callback and stdout:
            for line in stdout.splitlines():
                stream_callback(line)

        return CommandResult(
            returncode=last.returncode,
            stdout=stdout.strip(),
            stderr=stderr.strip() if stderr else "",
        )

    def run_safe(self, *commands: str, timeout: float = 30.0) -> CommandResult:
        """Like ``run`` but never raises.

        Returns returncode 124 on timeout, or 1 with the error text for any
        other exception.
        """
        try:
            return self.run(*commands, timeout=timeout)
        except subprocess.TimeoutExpired:
            return CommandResult(returncode=124, stdout="", stderr="Command timed out")
        except Exception as e:
            return CommandResult(returncode=1, stdout="", stderr=str(e))
# Demo: single command, two-stage pipe, env injection, and a forced timeout.
pipeline = ShellPipeline()
result = pipeline.run("echo hello world")
print("echo result: " + result.stdout)
result = pipeline.run("echo hello world", "tr '[:lower:]' '[:upper:]'")
print("Pipeline result: " + result.stdout)
custom_pipeline = ShellPipeline(env={"MY_VAR": "custom_value"})
result = custom_pipeline.run("echo $MY_VAR")
print("With env: " + result.stdout)
result = pipeline.run_safe("sleep 10", timeout=0.1)
print("Timeout caught: " + str(result.returncode == 124))
Solution
import subprocess
import os
from typing import Optional, Callable
from dataclasses import dataclass
@dataclass
class CommandResult:
    """Exit status and captured output; for a pipeline, of the LAST command."""

    returncode: int
    stdout: str
    stderr: str

    @property
    def ok(self) -> bool:
        """True when ``returncode`` is 0."""
        return self.returncode == 0


class ShellPipeline:
    """Run shell commands (``shell=True``, trusted input only), optionally piped.

    Args:
        env: Extra environment variables layered over ``os.environ``.
        cwd: Working directory for every command (None = inherit).
    """

    def __init__(self, env: Optional[dict] = None, cwd: Optional[str] = None):
        self._env = env
        self._cwd = cwd

    @staticmethod
    def _kill_and_reap(processes) -> None:
        """Kill, wait for, and close the parent-side pipes of every process.

        ``wait()`` (not ``communicate()``) so an orphaned shell grandchild
        holding a pipe cannot block us.
        """
        for proc in processes:
            proc.kill()
        for proc in processes:
            proc.wait()
            for stream in (proc.stdout, proc.stderr):
                if stream is not None:
                    stream.close()

    def run(self, *commands: str, timeout: Optional[float] = None,
            stream_callback: Optional[Callable[[str], None]] = None) -> CommandResult:
        """Run ``commands`` as a pipeline and return the last command's result.

        ``stream_callback`` (previously accepted but ignored) now receives each
        stdout line of the last command, after the pipeline finishes.
        Raises subprocess.TimeoutExpired after killing and reaping all stages.
        """
        if not commands:
            return CommandResult(1, "", "No commands provided")
        env = {**os.environ, **(self._env or {})}
        processes = []
        prev_stdout = None
        try:
            for i, cmd in enumerate(commands):
                is_last = i == len(commands) - 1
                proc = subprocess.Popen(
                    cmd, shell=True, stdin=prev_stdout,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE if is_last else subprocess.DEVNULL,
                    env=env, cwd=self._cwd, text=True,
                )
                # Drop the parent's copy of the read end so upstream gets SIGPIPE.
                if prev_stdout:
                    prev_stdout.close()
                prev_stdout = proc.stdout
                processes.append(proc)
        except BaseException:
            self._kill_and_reap(processes)
            raise
        last = processes[-1]
        try:
            stdout, stderr = last.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            self._kill_and_reap(processes)
            raise
        for p in processes[:-1]:
            p.wait()
        if stream_callback and stdout:
            for line in stdout.splitlines():
                stream_callback(line)
        return CommandResult(
            returncode=last.returncode,
            stdout=stdout.strip(),
            stderr=(stderr or "").strip(),
        )

    def run_safe(self, *commands: str, timeout: float = 30.0) -> CommandResult:
        """Like ``run`` but never raises: 124 on timeout, 1 on any other error."""
        try:
            return self.run(*commands, timeout=timeout)
        except subprocess.TimeoutExpired:
            return CommandResult(124, "", "Command timed out")
        except Exception as e:
            return CommandResult(1, "", str(e))
# Demo of the solution. The tr command is hoisted into a variable: a
# backslash-escaped quote inside an f-string replacement field is a
# SyntaxError in every Python version.
p = ShellPipeline()
upper_cmd = "tr '[:lower:]' '[:upper:]'"
print(f"echo result: {p.run('echo hello world').stdout}")
print(f"Pipeline result: {p.run('echo hello world', upper_cmd).stdout}")
print(f"With env: {ShellPipeline(env={'MY_VAR': 'custom_value'}).run('echo $MY_VAR').stdout}")
print(f"Timeout caught: {p.run_safe('sleep 10', timeout=0.1).returncode == 124}")
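stdin pipe closing: Closing prev_stdout in the parent after passing it to the next process's stdin is critical. prev_stdout is the read end of the pipe, and once the next process has inherited it, the parent's copy only does harm. If the parent keeps it open and the downstream process exits early (e.g. `head -1`), the upstream process never receives SIGPIPE/EPIPE because a reader still exists. It can then block forever writing into a full pipe. The Python subprocess documentation includes the same `p1.stdout.close()` step in its "Replacing shell pipeline" example. Forgetting it is one of the most common bugs when building pipelines manually with subprocess.Popen.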
import subprocess
from typing import List, Optional, Dict, Tuple
from dataclasses import dataclass
@dataclass
class CommandResult:
    """Captured outcome of running a command: status code plus both streams."""

    returncode: int  # process exit status
    stdout: str  # captured standard output
    stderr: str  # captured standard error

    @property
    def ok(self) -> bool:
        """Whether the command reported success (exit status zero)."""
        succeeded = self.returncode == 0
        return succeeded
# Exercise stub: replace the placeholder `pass` below with the implementation.
class ShellPipeline:
"""Execute shell command pipelines with proper process management.
Supports:
- Single command execution
- Piped pipelines: cmd1 | cmd2 | cmd3
- Environment variable injection
- Timeout handling
- Streaming stdout callback
Interface used by the tests and the reference solution:
- ShellPipeline(env=None, cwd=None)
- run(*commands, timeout=None, stream_callback=None) -> CommandResult (raises subprocess.TimeoutExpired)
- run_safe(*commands, timeout=30.0) -> CommandResult with returncode 124 on timeout
"""
passExpected Output
echo result: hello world\nPipeline result: HELLO WORLD\nWith env: custom_value\nTimeout caught: True
Hints
Hint 1: For pipelines: create each process with stdout=subprocess.PIPE, passing the previous process stdout as stdin for the next process. The final process captures the output.
Hint 2: Use subprocess.Popen for streaming. For timeout, use communicate(timeout=N) inside a try/except subprocess.TimeoutExpired.
Build a CLI test harness that captures stdout/stderr, injects stdin, and catches SystemExit to test CLI applications without spawning subprocesses.
import sys
import io
from typing import List, Optional, Callable, Any
from dataclasses import dataclass
@dataclass
class InvocationResult:
    """Outcome of one in-process CLI invocation."""

    exit_code: int
    stdout: str
    stderr: str

    @property
    def output_lines(self) -> List[str]:
        """Captured stdout split into lines, without line terminators."""
        return self.stdout.splitlines()


class CLITestHarness:
    """Invoke a CLI entry point in-process with captured stdio.

    During ``invoke``:
    - ``sys.stdout`` and ``sys.stderr`` are in-memory buffers.
    - ``sys.argv`` is ``["test_prog", *args]``.
    - ``sys.stdin`` is replaced when ``input_text`` is given.
    - ``env`` is overlaid on ``os.environ``.
    All of these are restored afterwards, even if the entry point raises.
    """

    @staticmethod
    def _exit_code(exc: SystemExit, stderr: io.StringIO) -> int:
        """Map ``SystemExit.code`` to an exit status the way the interpreter does.

        ``None`` becomes 0 and integers map to themselves. Anything else (e.g.
        the message in ``sys.exit("...")``) is written to ``stderr`` and
        yields 1; a bare ``int(code)`` would raise ValueError there.
        """
        code = exc.code
        if code is None:
            return 0
        if isinstance(code, int):
            return int(code)
        stderr.write(f"{code}\n")
        return 1

    def invoke(
        self,
        main_fn: Callable,
        args: List[str],
        input_text: Optional[str] = None,
        env: Optional[dict] = None,
    ) -> InvocationResult:
        """Run ``main_fn()`` as if invoked as ``test_prog *args``.

        Args:
            main_fn: Zero-argument entry point; it reads sys.argv itself.
            args: Command-line arguments, excluding the program name.
            input_text: Text served on stdin; None leaves sys.stdin untouched.
            env: Variables set in os.environ for the call; a value of None
                removes the variable. Previous values are restored afterwards.

        Returns:
            InvocationResult with the exit code and captured stdout/stderr.
            SystemExit becomes the exit code. Any other Exception is reported
            on stderr as "Unhandled exception: ..." with exit code 1.
        """
        import os  # only needed for env injection

        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()
        saved = (sys.stdout, sys.stderr, sys.stdin, sys.argv)
        saved_env: dict = {}
        exit_code = 0
        try:
            for key, value in (env or {}).items():
                saved_env[key] = os.environ.get(key)
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value
            sys.stdout = stdout_capture
            sys.stderr = stderr_capture
            sys.argv = ["test_prog", *args]
            if input_text is not None:
                sys.stdin = io.StringIO(input_text)
            try:
                main_fn()
            except SystemExit as e:
                exit_code = self._exit_code(e, stderr_capture)
            except Exception as e:
                stderr_capture.write(f"Unhandled exception: {e}\n")
                exit_code = 1
        finally:
            sys.stdout, sys.stderr, sys.stdin, sys.argv = saved
            for key, old in saved_env.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old
        return InvocationResult(
            exit_code=exit_code,
            stdout=stdout_capture.getvalue(),
            stderr=stderr_capture.getvalue(),
        )
# Test: a simple CLI app
import argparse
def greet_app():
    """Greet the single positional ``name`` argument read from sys.argv."""
    cli = argparse.ArgumentParser()
    cli.add_argument("name", help="Name to greet")
    name = cli.parse_args().name
    print(f"Hello, {name}!")
def stdin_app():
    """Prompt for a name on stdin and echo it back as ``Got: <name>``."""
    print("Got: " + input("Enter name: "))
harness = CLITestHarness()
result = harness.invoke(greet_app, ["Alice"])
print(f"Exit code: {result.exit_code}")
print(f"Output: {result.output_lines}")
print(f"Stderr empty: {not result.stderr.strip()}")
bad_result = harness.invoke(greet_app, [])
# argparse reports usage errors with exit status 2.
print(f"With bad args: {bad_result.exit_code}")
stdin_result = harness.invoke(stdin_app, [], input_text="Bob\n")
# input() writes its prompt to the captured stdout without a newline, so the
# line reads "Enter name: Got: Bob". Take the text after the LAST ": ".
print(f"With stdin: {stdin_result.output_lines[-1].rsplit(': ', 1)[-1]}")
Solution
import sys
import io
import argparse
from typing import List, Optional, Callable
from dataclasses import dataclass
@dataclass
class InvocationResult:
    """Outcome of one in-process CLI invocation."""

    exit_code: int
    stdout: str
    stderr: str

    @property
    def output_lines(self) -> List[str]:
        """Captured stdout split into lines, without line terminators."""
        return self.stdout.splitlines()


class CLITestHarness:
    """Run a CLI entry point in-process with swapped stdio, argv and env."""

    @staticmethod
    def _exit_code(exc: SystemExit, stderr) -> int:
        """Map ``SystemExit.code`` like the interpreter does.

        None becomes 0 and ints map to themselves. Any other value is written
        to stderr and yields 1; ``int(code)`` raised ValueError for
        ``sys.exit("msg")``.
        """
        if exc.code is None:
            return 0
        if isinstance(exc.code, int):
            return int(exc.code)
        stderr.write(f"{exc.code}\n")
        return 1

    def invoke(self, main_fn: Callable, args: List[str],
               input_text: Optional[str] = None, env=None) -> InvocationResult:
        """Call ``main_fn()`` with argv ``["test_prog", *args]`` and capture its output.

        ``input_text`` (if not None) is served on stdin. ``env`` is overlaid on
        os.environ for the call (None values unset a variable) and restored
        afterwards. Unhandled exceptions are reported on stderr with exit code 1.
        """
        import os  # only needed for env injection

        stdout_cap, stderr_cap = io.StringIO(), io.StringIO()
        old_out, old_err, old_in, old_argv = sys.stdout, sys.stderr, sys.stdin, sys.argv
        saved_env = {}
        exit_code = 0
        try:
            for key, value in (env or {}).items():
                saved_env[key] = os.environ.get(key)
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value
            sys.stdout = stdout_cap
            sys.stderr = stderr_cap
            sys.argv = ["test_prog", *args]
            if input_text is not None:
                sys.stdin = io.StringIO(input_text)
            try:
                main_fn()
            except SystemExit as e:
                exit_code = self._exit_code(e, stderr_cap)
            except Exception as e:
                stderr_cap.write(f"Unhandled exception: {e}\n")
                exit_code = 1
        finally:
            sys.stdout, sys.stderr, sys.stdin, sys.argv = old_out, old_err, old_in, old_argv
            for key, old in saved_env.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old
        return InvocationResult(exit_code=exit_code, stdout=stdout_cap.getvalue(), stderr=stderr_cap.getvalue())
def greet_app():
    """Greet the single positional ``name`` argument read from sys.argv."""
    cli = argparse.ArgumentParser()
    cli.add_argument("name")
    name = cli.parse_args().name
    print(f"Hello, {name}!")
def stdin_app():
    """Prompt for a name on stdin and echo it back as ``Got: <name>``."""
    print("Got: " + input("Enter name: "))
harness = CLITestHarness()
result = harness.invoke(greet_app, ["Alice"])
print(f"Exit code: {result.exit_code}")
print(f"Output: {result.output_lines}")
print(f"Stderr empty: {not result.stderr.strip()}")
# argparse exits with status 2 on a missing positional argument.
print(f"With bad args: {harness.invoke(greet_app, []).exit_code}")
stdin_result = harness.invoke(stdin_app, [], input_text="Bob\n")
# The input() prompt lands on the same captured line ("Enter name: Got: Bob"),
# so take the text after the LAST ": ".
print(f"With stdin: {stdin_result.output_lines[-1].rsplit(': ', 1)[-1]}")
Click's CliRunner: Click ships click.testing.CliRunner which is a production version of this harness. It handles mix_stderr=False (separate stdout/stderr capture), catch_exceptions=False (let exceptions propagate for easier debugging), and environment variable injection. The core mechanism is identical: swap sys.stdout/stderr/stdin before invoking, restore after. Understanding this makes you a better Click test writer.
import sys
import io
from typing import List, Optional, Callable, Any
from contextlib import contextmanager
from dataclasses import dataclass
@dataclass
class InvocationResult:
    """Exit status plus everything a CLI wrote to stdout and stderr."""

    exit_code: int  # status derived from SystemExit (0 on normal return)
    stdout: str  # full captured standard output
    stderr: str  # full captured standard error

    @property
    def output_lines(self) -> List[str]:
        """Standard output broken into lines, terminators removed."""
        captured = self.stdout
        return captured.splitlines()
# Exercise stub: replace the placeholder `pass` below with the implementation.
class CLITestHarness:
"""Test harness for CLI applications.
- invoke(main_fn, args, input_text) runs main_fn in a controlled environment
- Captures stdout, stderr, and exit code
- Injects simulated stdin input
- Sets sys.argv to ["test_prog"] + args for the duration of the call (as in the reference solution)
- Catches SystemExit and turns it into exit_code; other exceptions are reported on stderr with exit_code 1
- Restores sys.stdout, sys.stderr, sys.stdin and sys.argv after each invocation
  (the reference solution accepts an env argument but does not apply it)
"""
passExpected Output
Exit code: 0\nOutput: ['Hello, Alice!']\nStderr empty: True\nWith bad args: 2\nWith stdin: Bob
Hints
Hint 1: Use contextlib.redirect_stdout and redirect_stderr (or manually swap sys.stdout and sys.stderr) to capture output. Use io.StringIO() as the capture buffer.
Hint 2: To simulate stdin, set sys.stdin to io.StringIO(input_text) before calling main_fn. Always restore sys.stdout, sys.stderr, and sys.stdin in a finally block.
