Python Recursion Fundamentals Practice Problems & Exercises
Practice: Recursion Fundamentals
Easy
Write a recursive function that prints numbers from n down to 1, then prints "Go!". Do not use any loops.
This is the simplest possible recursion pattern: print, then recurse with a smaller value.
Solution
def countdown(n):
    if n <= 0:
        print("Go!")
        return
    print(n)
    countdown(n - 1)

countdown(5)
Every recursive function needs two ingredients: at least one base case and at least one recursive case. Here the base case (n <= 0) stops the recursion and prints the final message. The recursive case prints the current number and calls countdown with n - 1, which moves closer to the base case. If you forget the base case, the calls never stop. Python then raises RecursionError once the call depth exceeds the interpreter's recursion limit (1000 by default; check it with sys.getrecursionlimit()).
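A quick way to see why the base case matters is to remove it. The sketch below uses a deliberately broken function, countdown_no_base, which is a hypothetical name used only for this demonstration. The exact error message and limit depend on the Python version and on whether sys.setrecursionlimit() has been called.

```python
import sys

def countdown_no_base(n):
    """Deliberately broken: there is no base case, so the calls never stop."""
    return countdown_no_base(n - 1)

print("recursion limit:", sys.getrecursionlimit())  # 1000 unless changed

try:
    countdown_no_base(5)
except RecursionError as exc:
    # Python aborts the runaway chain instead of exhausting the C stack.
    print("RecursionError:", exc)
```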
def countdown(n):
    """Print numbers from n down to 1, then print 'Go!'.
    Use recursion; no loops allowed.
    Return None.
    """
    # TODO: implement
    pass
Expected Output
5
4
3
2
1
Go!
Hints
Hint 1: The base case is when n reaches 0 (or less) — print "Go!" and return.
Hint 2: The recursive case prints n, then calls countdown(n - 1).
Implement the classic recursive factorial function. Handle the edge case where n is negative by raising a ValueError.
Test with factorial(0), factorial(1), factorial(5), and factorial(10).
Solution
def factorial(n):
    if n < 0:
        raise ValueError("factorial undefined for negative n")
    if n == 0:
        return 1
    return n * factorial(n - 1)

print(factorial(0))
print(factorial(1))
print(factorial(5))
print(factorial(10))
The call chain for factorial(5) creates 6 stack frames, from factorial(5) down to factorial(0). Each frame must wait for its own call, factorial(n - 1), to return before it can multiply. The base case n == 0 returns 1. As the stack unwinds, each frame multiplies its n by the returned value, so the frames return 1, 1, 2, 6, 24, 120 in that order. Every level adds a frame, so inputs in the high 900s and above raise RecursionError under the default limit of 1000. The exact cutoff depends on how many frames are already on the stack.
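The unwinding order can be observed directly. The sketch below is the same recursion with one change: each frame appends its result to a caller-supplied list just before returning. The function name factorial_traced and its returns parameter are hypothetical additions for this demonstration.

```python
def factorial_traced(n, returns):
    """Same recursion as factorial, but records each frame's return value.

    `returns` is a list supplied by the caller (a hypothetical helper
    parameter used only to observe the unwinding order).
    """
    if n < 0:
        raise ValueError("factorial undefined for negative n")
    if n == 0:
        result = 1
    else:
        # This frame pauses here until factorial_traced(n - 1) returns.
        result = n * factorial_traced(n - 1, returns)
    returns.append(result)  # runs only after every deeper frame has finished
    return result

returns = []
print(factorial_traced(5, returns))  # 120
print(returns)                       # [1, 1, 2, 6, 24, 120]
```

The deepest frame (n == 0) appends first, which is why the list reads in the same order as the unwinding sequence described above.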
def factorial(n):
    """Compute n! recursively.
    factorial(0) = 1 (base case)
    factorial(n) = n * factorial(n-1)
    Raise ValueError for negative n.
    """
    # TODO: implement
    pass
Expected Output
1
1
120
3628800
Hints
Hint 1: The base case is n == 0, which returns 1 (by definition, 0! = 1).
Hint 2: The recursive case returns n * factorial(n - 1).
Write a function that computes the sum of a list recursively. Do not use loops or the built-in sum().
This demonstrates the "head + tail" decomposition pattern: process the first element, recurse on the rest.
Solution
def recursive_sum(lst):
    if not lst:
        return 0
    return lst[0] + recursive_sum(lst[1:])

print(recursive_sum([1, 2, 3, 4, 5]))
print(recursive_sum([]))
print(recursive_sum([42]))
print(recursive_sum([10, -5, -8]))
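The head-tail pattern splits a list into lst[0] (the head) and lst[1:] (the tail). Each recursive call handles one element and passes a shorter list onward. For [1, 2, 3], the chain expands as 1 + recursive_sum([2, 3]), then 1 + 2 + recursive_sum([3]), then 1 + 2 + 3 + recursive_sum([]), which evaluates to 1 + 2 + 3 + 0 = 6. Note that lst[1:] copies the remaining elements on every call, and each pending frame keeps its copy alive, so this version takes O(n^2) time and memory overall. The built-in sum() runs in O(n) time with constant extra space and is the right choice in production. The recursive version is also limited to lists shorter than the recursion limit (about 1000 elements by default).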
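One way to keep the head-tail idea without copying is to pass an index instead of a slice. The sketch below assumes a hypothetical extra parameter i that marks where the "tail" starts. It brings the total work down to O(n), but the recursion depth is still n.

```python
def recursive_sum_from(lst, i=0):
    """Sum lst[i:] recursively without slicing (hypothetical variant).

    The index parameter i is an assumption of this sketch, not part of
    the exercise's required signature.
    """
    if i == len(lst):  # base case: no elements left to add
        return 0
    # Head is lst[i]; the tail is represented by i + 1, so nothing is copied.
    return lst[i] + recursive_sum_from(lst, i + 1)

print(recursive_sum_from([1, 2, 3, 4, 5]))  # 15
print(recursive_sum_from([]))               # 0
print(recursive_sum_from([10, -5, -8]))     # -3
```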
def recursive_sum(lst):
    """Return the sum of all elements in lst.
    Use recursion: no loops, no built-in sum().
    """
    # TODO: implement
    pass
Expected Output
15
0
42
-3
Hints
Hint 1: The base case is an empty list — the sum of nothing is 0.
Hint 2: The recursive case returns the first element plus the sum of the rest: lst[0] + recursive_sum(lst[1:]).
Write a recursive function that computes base raised to exp without using the ** operator or pow(). Assume exp is a non-negative integer.
Test with power(2, 0), power(2, 10), power(3, 4), and power(5, 3).
Solution
def power(base, exp):
    if exp < 0:
        raise ValueError("exp must be non-negative")
    if exp == 0:
        return 1
    return base * power(base, exp - 1)

print(power(2, 0))
print(power(2, 10))
print(power(3, 4))
print(power(5, 3))
This version runs in O(exp) time. power(base, exp) makes exp + 1 calls, down to power(base, 0), and performs exp multiplications. For power(2, 10), that is 2 * 2 * ... * 2 * 1 with 10 multiplications. A faster approach is exponentiation by squaring, which needs only O(log exp) calls. If exp is even, compute power(base * base, exp // 2); if it is odd, compute base * power(base, exp - 1). For integers, CPython's ** operator uses a similar repeated-squaring method, but the linear version here shows the recursive structure most clearly.
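The squaring rule described above can be sketched as follows. fast_power is a hypothetical name, and like the exercise it assumes a non-negative integer exponent and avoids ** and pow().

```python
def fast_power(base, exp):
    """Exponentiation by squaring (a sketch; fast_power is a hypothetical name).

    Assumes exp is a non-negative integer, as in the exercise.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    if exp == 0:
        return 1
    if exp % 2 == 0:
        # Even exponent: base^exp == (base*base)^(exp/2), so the exponent halves.
        return fast_power(base * base, exp // 2)
    # Odd exponent: peel off one factor so the next call sees an even exponent.
    return base * fast_power(base, exp - 1)

print(fast_power(2, 10))  # 1024
print(fast_power(3, 4))   # 81
print(fast_power(5, 3))   # 125
```

An odd step is always followed by an even step, so the exponent at least halves every two calls. fast_power(2, 1000) therefore needs about 2 * log2(1000), or roughly 20 calls, instead of the 1001 calls the linear version makes.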
def power(base, exp):
    """Compute base ** exp recursively.
    exp must be a non-negative integer.
    Do not use ** or pow().
    """
    # TODO: implement
    pass
Expected Output
1
1024
81
125
Hints
Hint 1: Base case: any number raised to the power 0 is 1.
Hint 2: Recursive case: base * power(base, exp - 1) reduces exp by 1 each call.
Medium
Implement recursive binary search on a sorted list. Return the index of the target if found, or -1 if not found. Use the low and high parameters to track the current search range.
Test with [1, 3, 5, 7, 9, 11, 13, 15] and targets 7, 6, 1, 15.
Solution
def binary_search(arr, target, low=0, high=None):
    if high is None:
        high = len(arr) - 1
    if low > high:
        return -1
    mid = (low + high) // 2
    if arr[mid] == target:
        return mid
    if arr[mid] < target:
        return binary_search(arr, target, mid + 1, high)
    return binary_search(arr, target, low, mid - 1)

nums = [1, 3, 5, 7, 9, 11, 13, 15]
print(binary_search(nums, 7))
print(binary_search(nums, 6))
print(binary_search(nums, 1))
print(binary_search(nums, 15))
Binary search halves the search space on each call, giving O(log n) time. For a list of n elements it probes at most floor(log2(n)) + 1 elements, plus one final call when the range runs out. For the 8-element list here, that means at most 4 probes and a recursion depth of at most 5. This is one of the safest uses of recursion in practice: even for a billion elements the depth stays around 30 frames, well within Python's recursion limit. There are two base cases: the target is found (arr[mid] == target), or the search space is exhausted (low > high).
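To check the depth bound empirically, the sketch below threads a hypothetical depth counter through the same recursion and returns it alongside the index. It then tries every target from 0 to 16 (hits and misses) on the 8-element list.

```python
def binary_search_depth(arr, target, low=0, high=None, depth=1):
    """Same algorithm as binary_search, but also returns the recursion depth.

    The depth parameter is a hypothetical addition for this demonstration.
    """
    if high is None:
        high = len(arr) - 1
    if low > high:
        return -1, depth
    mid = (low + high) // 2
    if arr[mid] == target:
        return mid, depth
    if arr[mid] < target:
        return binary_search_depth(arr, target, mid + 1, high, depth + 1)
    return binary_search_depth(arr, target, low, mid - 1, depth + 1)

nums = [1, 3, 5, 7, 9, 11, 13, 15]
# Deepest call chain over all hits and misses in the range 0..16.
worst = max(binary_search_depth(nums, t)[1] for t in range(17))
print(worst)  # 5

big = list(range(10**6))
print(binary_search_depth(big, -1)[1])  # 20, within floor(log2(n)) + 2 = 21
```

The bound floor(log2(n)) + 2 can be computed exactly as n.bit_length() + 1 for n >= 1, which avoids floating-point log2.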
def binary_search(arr, target, low=0, high=None):
    """Search for target in sorted arr.
    Return the index if found, -1 if not found.
    Use recursion; no loops.
    """
    # TODO: implement
    pass
Expected Output
3
-1
0
7
Hints
Hint 1: Base case: if low > high, the target is not in the array — return -1.
Hint 2: Compute mid = (low + high) // 2. If arr[mid] == target, return mid. Otherwise recurse on the correct half.
Implement all three standard binary tree traversals: in-order, pre-order, and post-order. Each should return a list of node values.
Test with the tree:
        4
      /   \
     2     6
    / \   / \
   1   3 5   7
Solution
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(node):
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)

def preorder(node):
    if node is None:
        return []
    return [node.val] + preorder(node.left) + preorder(node.right)

def postorder(node):
    if node is None:
        return []
    return postorder(node.left) + postorder(node.right) + [node.val]

tree = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)
print(inorder(tree))
print(preorder(tree))
print(postorder(tree))
The only difference between the three traversals is where [node.val] goes in the concatenation. In-order (left, root, right) produces sorted output for a binary search tree, as the test tree shows. Pre-order (root, left, right) is useful for serializing a tree. Post-order (left, right, root) is useful for deletion or for computing sizes bottom-up. The base case node is None covers both the missing children of leaf nodes and an empty tree. Recursion depth equals the tree's height: O(log n) for balanced trees and O(n) for degenerate (linear) trees.
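The + concatenations above build a new list at every node. Each value is therefore copied once per level above it: O(n log n) total work for a balanced tree and O(n^2) for a degenerate one. A common alternative is to append into one shared accumulator list. The sketch below shows this for in-order; inorder_into and inorder_fast are hypothetical names, and pre-order and post-order follow by moving the append.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder_into(node, out):
    """Append the in-order values of node's subtree to out (hypothetical helper)."""
    if node is None:
        return
    inorder_into(node.left, out)
    out.append(node.val)  # root between left and right: in-order
    inorder_into(node.right, out)

def inorder_fast(node):
    out = []
    inorder_into(node, out)
    return out

tree = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7)),
)
print(inorder_fast(tree))  # [1, 2, 3, 4, 5, 6, 7]
```

Each value is appended exactly once, so the total work is O(n). The recursion depth is still the tree's height.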
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(node):
    """Return a list of values in in-order: left, root, right."""
    # TODO: implement
    pass

def preorder(node):
    """Return a list of values in pre-order: root, left, right."""
    # TODO: implement
    pass

def postorder(node):
    """Return a list of values in post-order: left, right, root."""
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 4, 5, 6, 7]
[4, 2, 1, 3, 6, 5, 7]
[1, 3, 2, 5, 7, 6, 4]
Hints
Hint 1: Base case: if node is None, return an empty list.
Hint 2: In-order: recurse left, then current val, then recurse right. Pre-order: current val first. Post-order: current val last.
Write a function that takes an arbitrarily nested list and returns a flat list of all leaf values. The nesting depth can be anything.
Test with [1, [2, [3, 4], 5], 6], [[[[1]]], [[[2]]]], [], and [1, [2, [], [3]], 4].
Solution
def flatten(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten(item))  # recursive case: a sub-list
        else:
            result.append(item)  # base case: a leaf value
    return result

print(flatten([1, [2, [3, 4], 5], 6]))
print(flatten([[[[1]]], [[[2]]]]))
print(flatten([]))
print(flatten([1, [2, [], [3]], 4]))
This is a classic recursive pattern for nested data. Each element is either a leaf (base case — append directly) or a sub-list (recursive case — flatten it and extend). The recursion depth equals the maximum nesting level. For [1, [2, [3, [4]]]], the depth is 4. This pattern appears everywhere in production: flattening JSON responses, walking nested configs, and processing tree-structured data. Use isinstance(item, list) rather than type(item) == list to also handle subclasses.
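A generator variant follows the same leaf/sub-list split but yields values one at a time instead of building an intermediate list at every level. It relies on Python's yield from to delegate to the recursive call. The name iter_flatten is hypothetical, and the recursion depth still equals the maximum nesting level.

```python
def iter_flatten(lst):
    """Yield leaf values from an arbitrarily nested list (hypothetical generator variant)."""
    for item in lst:
        if isinstance(item, list):
            yield from iter_flatten(item)  # recursive case: delegate to the sub-list
        else:
            yield item                     # base case: a leaf value

print(list(iter_flatten([1, [2, [3, 4], 5], 6])))  # [1, 2, 3, 4, 5, 6]
```

Callers that only need to scan the values, such as any() or a for loop, can consume the generator without materializing the flat list.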
def flatten(lst):
    """Recursively flatten an arbitrarily nested list.
    Return a flat list of all leaf values.
    Example: [1, [2, [3, 4], 5], 6] -> [1, 2, 3, 4, 5, 6]
    """
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 4, 5, 6]
[1, 2]
[]
[1, 2, 3, 4]
Hints
Hint 1: Iterate over each item. If it is a list (use isinstance), recurse into it. Otherwise, it is a leaf — append it.
Hint 2: Use extend() to add the flattened sub-list results to your accumulator.
Reverse a string using recursion only. Do not use loops, [::-1] slicing, or the reversed() built-in.
Test with "abcd", "", "a", and "Hello".
Solution
def reverse_string(s):
    if len(s) <= 1:
        return s
    return reverse_string(s[1:]) + s[0]

print(reverse_string("abcd"))
print(reverse_string(""))
print(reverse_string("a"))
print(reverse_string("Hello"))
This follows the same head-tail decomposition as the recursive list sum. For "abcd", the chain expands as reverse_string("bcd") + "a", then reverse_string("cd") + "b" + "a", then reverse_string("d") + "c" + "b" + "a", which evaluates to "d" + "c" + "b" + "a" = "dcba". Each call creates new strings through s[1:] and concatenation, so this is O(n^2) in both time and memory. The recursion depth also grows with the string's length. In production, use s[::-1], which is O(n). This exercise exists to build intuition about how recursion decomposes sequential data.
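Because the depth of the head-tail version grows with the string's length, long strings hit the recursion limit. A divide-and-conquer split avoids this: the reverse of left + right is reverse(right) + reverse(left). The sketch below uses the hypothetical name reverse_halves and, like the exercise, avoids loops, [::-1] and reversed(). Its depth is O(log n) and its total work O(n log n).

```python
def reverse_halves(s):
    """Reverse s by splitting it in half (a divide-and-conquer sketch)."""
    if len(s) <= 1:  # base case: empty or single character is already reversed
        return s
    mid = len(s) // 2
    # The reverse of (left + right) is reverse(right) + reverse(left).
    return reverse_halves(s[mid:]) + reverse_halves(s[:mid])

print(reverse_halves("abcd"))   # dcba
print(reverse_halves("Hello"))  # olleH
```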
def reverse_string(s):
    """Reverse a string recursively.
    No loops, no slicing with [::-1], no reversed().
    """
    # TODO: implement
    pass
Expected Output
dcba
(empty line)
a
olleH
Hints
Hint 1: Base case: an empty string or single character is already reversed.
Hint 2: Recursive case: move the first character to the end — reverse_string(s[1:]) + s[0].
Hard
Implement both naive and memoized recursive Fibonacci. The naive version should be O(2^n) and the memoized version O(n) using a hand-rolled dictionary cache (not lru_cache).
Write a compare function that times both versions on n=30 to show the dramatic performance difference.
Solution
import time

def fib_naive(n):
    if n <= 1:
        return n
    return fib_naive(n - 1) + fib_naive(n - 2)

def fib_memo(n, cache=None):
    if cache is None:
        cache = {}  # fresh cache for each top-level call
    if n <= 1:
        return n
    if n in cache:
        return cache[n]
    result = fib_memo(n - 1, cache) + fib_memo(n - 2, cache)
    cache[n] = result
    return result

def compare(n):
    """Time both versions on the same n, print their results, return the timings."""
    start = time.perf_counter()
    result_naive = fib_naive(n)
    naive_time = time.perf_counter() - start

    start = time.perf_counter()
    result_memo = fib_memo(n)
    memo_time = time.perf_counter() - start

    print(f"fib_naive({n}) = {result_naive}")
    print(f"fib_memo({n}) = {result_memo}")
    return naive_time, memo_time

naive_time, memo_time = compare(30)
print(f"fib_memo(200) = {fib_memo(200)}")
print(f"naive: {naive_time:.3f} s, memoized: {memo_time:.6f} s")
Naive Fibonacci takes exponential time because every subproblem is recomputed from scratch. O(2^n) is a valid upper bound; the actual growth rate is about 1.618^n. fib_naive(30) makes 2,692,537 calls, about 2.7 million, and fib_naive(50) would make roughly 40 billion. Memoization collapses the exponential call tree into a roughly linear chain: each unique n is computed once and then served from the cache. The memoized version computes fib(200) almost instantly because it stores only 199 results (n = 2 through 200) and makes 399 calls in total. The cache trades O(n) space for an exponential reduction in time. In production, prefer functools.lru_cache (or functools.cache on Python 3.9+) over a hand-rolled dict. Its cache stays consistent when the function is called from multiple threads, and cache_info() reports hit and miss statistics.
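The call counts and the lru_cache alternative can both be checked with a short sketch. fib_naive_counted and fib_cached are hypothetical names; functools.lru_cache and its cache_info() method are standard library.

```python
from functools import lru_cache

call_count = 0  # incremented on every call to fib_naive_counted

def fib_naive_counted(n):
    """Naive Fibonacci that counts its own calls (hypothetical helper)."""
    global call_count
    call_count += 1
    if n <= 1:
        return n
    return fib_naive_counted(n - 1) + fib_naive_counted(n - 2)

@lru_cache(maxsize=None)
def fib_cached(n):
    """Memoized by functools.lru_cache instead of a hand-rolled dict."""
    if n <= 1:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)

fib_naive_counted(30)
print(call_count)               # 2692537
print(fib_cached(200))          # 280571172992510140037611932413038677189525
print(fib_cached.cache_info())  # 198 hits, 201 misses for this first call
```

The naive count matches the closed form 2 * F(n + 1) - 1 (here 2 * 1346269 - 1). Unlike the hand-rolled version, lru_cache also caches n = 0 and n = 1, which is why it records 201 misses rather than 199.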
import time

def fib_naive(n):
    """Naive recursive Fibonacci, O(2^n)."""
    # TODO: implement
    pass

def fib_memo(n, cache=None):
    """Memoized recursive Fibonacci, O(n).
    Use a dict cache (no lru_cache decorator).
    """
    # TODO: implement with hand-rolled memoization
    pass

def compare(n):
    """Time both versions and print results."""
    # TODO: implement timing comparison
    pass
Expected Output
fib_naive(30) = 832040
fib_memo(30) = 832040
fib_memo(200) = 280571172992510140037611932413038677189525
(followed by a timing line from compare(); its numbers vary from run to run)
Hints
Hint 1: Naive: if n <= 1 return n, else return fib(n-1) + fib(n-2). This recomputes overlapping subproblems.
Hint 2: Memoized: check whether n is in the cache before computing. If it is not, compute the value and store it. Either pass the cache explicitly or use a mutable default dict (note that a mutable default persists across calls).
Implement three recursive tree functions: max_depth, tree_size, and tree_sum. Each follows the same pattern — base case for None, then combine results from left and right subtrees.
Test with a balanced tree, a degenerate (linear) tree, and an empty tree.
Solution
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(node):
    if node is None:
        return 0
    return 1 + max(max_depth(node.left), max_depth(node.right))

def tree_size(node):
    if node is None:
        return 0
    return 1 + tree_size(node.left) + tree_size(node.right)

def tree_sum(node):
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)

# Balanced tree:     4
#                  /   \
#                 2     6
#                / \   / \
#               1   3 5   7
balanced = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)

# Degenerate tree: 1 -> 2 -> 3 -> 4 (each node is the right child of the previous one)
degenerate = TreeNode(1, None,
    TreeNode(2, None,
        TreeNode(3, None,
            TreeNode(4))))

print(f"depth={max_depth(balanced)}, size={tree_size(balanced)}, sum={tree_sum(balanced)}")
print(f"depth={max_depth(degenerate)}, size={tree_size(degenerate)}, sum={tree_sum(degenerate)}")
print(f"depth={max_depth(None)}, size={tree_size(None)}, sum={tree_sum(None)}")
All three functions follow the same recursive template: base case returns 0 for None, recursive case combines the current node with results from both children. The only difference is the combining operation: max for depth, + for size and sum. This pattern generalizes to almost any tree computation — finding min/max values, checking balance, validating BST properties. For the degenerate tree (all right children), depth equals size (4), showing how unbalanced trees degrade recursion depth to O(n), making them vulnerable to RecursionError on large inputs.
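To see the RecursionError risk concretely, the sketch below builds a degenerate chain of 5000 nodes and compares the recursive max_depth with an explicit-stack version. The chain is built with a loop so construction itself does not recurse. max_depth_iterative is a hypothetical name, and the recursive call is expected to fail only under the default recursion limit of 1000.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(node):
    if node is None:
        return 0
    return 1 + max(max_depth(node.left), max_depth(node.right))

def max_depth_iterative(node):
    """Depth via an explicit stack of (node, depth) pairs instead of the call stack."""
    best = 0
    stack = [(node, 1)] if node is not None else []
    while stack:
        current, depth = stack.pop()
        best = max(best, depth)
        for child in (current.left, current.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best

# Degenerate (all right children) chain 1 -> 2 -> ... -> 5000, built bottom-up.
root = None
for val in range(5000, 0, -1):
    root = TreeNode(val, None, root)

try:
    max_depth(root)
except RecursionError:
    print("recursive max_depth: RecursionError")
print("iterative:", max_depth_iterative(root))  # 5000
```

The explicit stack lives on the heap, so its size is limited by memory rather than by sys.getrecursionlimit().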
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(node):
    """Return the maximum depth of the tree.
    An empty tree has depth 0. A single node has depth 1.
    """
    # TODO: implement
    pass

def tree_size(node):
    """Return the total number of nodes in the tree."""
    # TODO: implement
    pass

def tree_sum(node):
    """Return the sum of all node values."""
    # TODO: implement
    pass
Expected Output
depth=3, size=7, sum=28
depth=4, size=4, sum=10
depth=0, size=0, sum=0
Hints
Hint 1: Base case for all three: if node is None, return 0.
Hint 2: max_depth: return 1 + max(max_depth(left), max_depth(right)). tree_size: return 1 + size(left) + size(right). tree_sum: return val + sum(left) + sum(right).
Implement merge sort with two functions: merge_sort (recursive splitting) and merge (combining two sorted lists). This is the canonical divide-and-conquer algorithm.
Test with [5, 2, 8, 1, 9, 3], [], [42], and [3, 1, 4, 1, 3, 2].
Solution
def merge_sort(arr):
    if len(arr) <= 1:
        return list(arr)  # copy, so the caller always gets a new list
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])  # sort each half recursively
    right = merge_sort(arr[mid:])
    return merge(left, right)

def merge(left, right):
    result = []
    i = j = 0
    # Two pointers: repeatedly take the smaller front element.
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:  # <= keeps equal elements in original order (stable)
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    # One side is exhausted; append whatever remains of the other.
    result.extend(left[i:])
    result.extend(right[j:])
    return result

print(merge_sort([5, 2, 8, 1, 9, 3]))
print(merge_sort([]))
print(merge_sort([42]))
print(merge_sort([3, 1, 4, 1, 3, 2]))
Merge sort runs in O(n log n) time in all cases. It always splits in half, giving about log2(n) levels, and merges n elements in total at each level. The recursion depth is O(log n), so it is safe even for very large lists: a million elements is only about 20 frames deep. The merge function uses the two-pointer technique: compare the fronts of both sorted halves, take the smaller one, and advance that pointer. The <= comparison makes the sort stable. On ties it takes the element from the left half, so equal elements keep their original order. The trade-off is O(n) auxiliary memory for the merged lists. Quicksort, by contrast, sorts in place with only O(log n) stack space when implemented carefully, but its worst case is O(n^2).
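Stability is easiest to see with records that share a sort key. The sketch below is the same merge sort with a key function applied in the comparison. merge_sort_by and its key parameter are hypothetical, and the merge is inlined so the block is self-contained.

```python
def merge_sort_by(items, key):
    """Merge sort on key(item); a sketch of the same algorithm with a key function."""
    if len(items) <= 1:
        return list(items)
    mid = len(items) // 2
    left = merge_sort_by(items[:mid], key)
    right = merge_sort_by(items[mid:], key)
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        # <= takes from the left half on ties, preserving the original order.
        if key(left[i]) <= key(right[j]):
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result

records = [("bob", 3), ("amy", 1), ("cat", 3), ("dan", 1)]
print(merge_sort_by(records, key=lambda r: r[1]))
# [('amy', 1), ('dan', 1), ('bob', 3), ('cat', 3)]
```

"amy" stays ahead of "dan" and "bob" ahead of "cat", exactly as in the input. Changing <= to < would still sort by key but could swap equal-key records.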
def merge_sort(arr):
    """Sort arr using recursive merge sort.
    Return a new sorted list.
    """
    # TODO: implement
    pass

def merge(left, right):
    """Merge two sorted lists into one sorted list."""
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 5, 8, 9]
[]
[42]
[1, 1, 2, 3, 3, 4]
Hints
Hint 1: Base case: a list of 0 or 1 elements is already sorted — return it.
Hint 2: Recursive case: split in half, sort each half recursively, then merge. The merge function uses two pointers to combine two sorted lists.
