
Python Recursion Fundamentals Practice Problems & Exercises

Practice: Recursion Fundamentals

11 problems · 4 Easy · 4 Medium · 3 Hard · 40-55 min

Easy

#1 Recursive Countdown (Easy)
Tags: base-case, recursive-case

Write a recursive function that prints numbers from n down to 1, then prints "Go!". Do not use any loops.

This is the simplest possible recursion pattern: print, then recurse with a smaller value.

Python
def countdown(n):
    if n <= 0:
        print("Go!")
        return
    print(n)
    countdown(n - 1)

countdown(5)
Solution

Every recursive function needs at least one base case and one recursive case. Here the base case (n <= 0) stops the recursion and prints the final message. The recursive case prints the current number and calls countdown with n - 1, which moves one step closer to the base case. If you forget the base case, the calls never stop on their own, and Python raises RecursionError once the call stack reaches its limit (1000 frames by default; see sys.getrecursionlimit()).
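To watch the missing-base-case failure without flooding the console, the sketch below uses a hypothetical helper, `probe_depth`, that recurses with no base case, counts how deep it gets, and catches the resulting RecursionError. The exact depth depends on how many frames are already on the stack, so the sketch reports it rather than predicting it.

```python
import sys


def probe_depth():
    """Recurse with no base case and report how deep Python allowed us to go."""
    depth = 0

    def runaway():
        nonlocal depth
        depth += 1
        runaway()  # no base case: nothing ever stops this call chain

    try:
        runaway()
    except RecursionError:
        pass  # the interpreter stopped the runaway recursion for us
    return depth


limit = sys.getrecursionlimit()
reached = probe_depth()
print(f"recursion limit: {limit}, depth reached before RecursionError: {reached}")
```

The depth reached is always somewhat below the limit, because the frames that called `probe_depth` count against the same budget.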

def countdown(n):
    """Print numbers from n down to 1, then print 'Go!'.
    Use recursion — no loops allowed.
    Return None.
    """
    # TODO: implement
    pass
Expected Output
5
4
3
2
1
Go!
Hints

Hint 1: The base case is when n reaches 0 (or less) — print "Go!" and return.

Hint 2: The recursive case prints n, then calls countdown(n - 1).

#2 Recursive Factorial (Easy)
Tags: factorial, base-case

Implement the classic recursive factorial function. Handle the edge case where n is negative by raising a ValueError.

Test with factorial(0), factorial(1), factorial(5), and factorial(10).

Python
def factorial(n):
    if n < 0:
        raise ValueError("factorial undefined for negative n")
    if n == 0:
        return 1
    return n * factorial(n - 1)

print(factorial(0))
print(factorial(1))
print(factorial(5))
print(factorial(10))
Solution

The call chain for factorial(5) creates 6 stack frames, factorial(5) down to factorial(0). Each frame must wait for the call it made to return before it can do its multiplication. The base case n == 0 returns 1, and as the stack unwinds, each frame multiplies its n by the returned value, producing 1, 1, 2, 6, 24, 120. Because every call adds a frame, n around 1000 will exceed CPython's default recursion limit and raise RecursionError; the exact threshold depends on how many frames are already on the stack.
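To make the winding and unwinding visible, here is a minimal sketch of the same recursion with logging added. The name `factorial_traced` and its `depth` parameter are hypothetical; `depth` exists only to indent the output by the current stack depth.

```python
def factorial_traced(n, depth=0):
    """Same recursion as factorial(), but prints each call and each return value.

    The indentation shows how deep the call stack is at that moment.
    """
    if n < 0:
        raise ValueError("factorial undefined for negative n")
    indent = "  " * depth
    print(f"{indent}factorial({n}) called")
    if n == 0:
        result = 1  # base case: the stack stops growing here
    else:
        # this frame pauses here until factorial(n - 1) returns
        result = n * factorial_traced(n - 1, depth + 1)
    print(f"{indent}factorial({n}) returns {result}")
    return result


factorial_traced(3)
```

For n = 3 the calls descend to factorial(0), and the return lines then climb back out in reverse order with the values 1, 1, 2, 6.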

def factorial(n):
    """Compute n! recursively.
    factorial(0) = 1 (base case)
    factorial(n) = n * factorial(n-1)
    Raise ValueError for negative n.
    """
    # TODO: implement
    pass
Expected Output
1
1
120
3628800
Hints

Hint 1: The base case is n == 0, which returns 1 (by definition, 0! = 1).

Hint 2: The recursive case returns n * factorial(n - 1).

#3 Recursive Sum of a List (Easy)
Tags: recursive-decomposition, list

Write a function that computes the sum of a list recursively. Do not use loops or the built-in sum().

This demonstrates the "head + tail" decomposition pattern: process the first element, recurse on the rest.

Python
def recursive_sum(lst):
    if not lst:
        return 0
    return lst[0] + recursive_sum(lst[1:])

print(recursive_sum([1, 2, 3, 4, 5]))
print(recursive_sum([]))
print(recursive_sum([42]))
print(recursive_sum([10, -5, -8]))
Solution

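The head-tail pattern splits a list into lst[0] (the head) and lst[1:] (the tail). Each recursive call handles one element and passes a shorter list forward. For [1, 2, 3], the chain is 1 + recursive_sum([2, 3]), then 1 + 2 + recursive_sum([3]), then 1 + 2 + 3 + recursive_sum([]), and finally 1 + 2 + 3 + 0 = 6. Note that lst[1:] copies the remaining elements on every call, so the total work is O(n^2); because every frame keeps its slice alive until it returns, peak memory is O(n^2) as well. The recursion depth is n, so lists longer than roughly 1000 elements also hit the recursion limit. The built-in sum() runs in O(n) time with O(1) extra memory and is the right choice in production.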
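One way to keep the head-tail idea while avoiding the slice copies is to pass an index instead of a shorter list. This is a minimal sketch with a hypothetical name, `recursive_sum_indexed`; it does O(n) total work, but the recursion depth is still n, so the recursion limit still applies.

```python
def recursive_sum_indexed(lst, i=0):
    """Sum lst[i:] recursively without creating any slices."""
    if i >= len(lst):
        return 0  # base case: no elements left
    # head is lst[i]; the "tail" is just the same list starting at i + 1
    return lst[i] + recursive_sum_indexed(lst, i + 1)


print(recursive_sum_indexed([1, 2, 3, 4, 5]))  # 15
print(recursive_sum_indexed([10, -5, -8]))     # -3
```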

def recursive_sum(lst):
    """Return the sum of all elements in lst.
    Use recursion — no loops, no built-in sum().
    """
    # TODO: implement
    pass
Expected Output
15
0
42
-3
Hints

Hint 1: The base case is an empty list — the sum of nothing is 0.

Hint 2: The recursive case returns the first element plus the sum of the rest: lst[0] + recursive_sum(lst[1:]).

#4 Recursive Power Function (Easy)
Tags: base-case, recursive-reduction

Write a recursive function that computes base raised to exp without using the ** operator or pow(). Assume exp is a non-negative integer.

Test with power(2, 0), power(2, 10), power(3, 4), and power(5, 3).

Python
def power(base, exp):
    if exp < 0:
        raise ValueError("exp must be non-negative")
    if exp == 0:
        return 1
    return base * power(base, exp - 1)

print(power(2, 0))
print(power(2, 10))
print(power(3, 4))
print(power(5, 3))
Solution

This version is O(exp): it makes one recursive call per unit of the exponent. For power(2, 10), the chain expands to 2 * 2 * ... * 2 * 1, with 10 multiplications. A more efficient approach, exponentiation by squaring, needs only O(log exp) calls: if exp is even, compute power(base * base, exp // 2); if it is odd, compute base * power(base, exp - 1). CPython's integer ** uses a variant of this binary exponentiation algorithm, but the linear version here shows the recursive structure most clearly.
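The even/odd rule described above can be sketched directly as a recursive function. The name `fast_power` is hypothetical, and this is one common formulation of exponentiation by squaring, not necessarily the exact algorithm CPython uses internally.

```python
def fast_power(base, exp):
    """Compute base ** exp by squaring, using O(log exp) recursive calls."""
    if exp < 0:
        raise ValueError("exp must be non-negative")
    if exp == 0:
        return 1  # same base case as power()
    if exp % 2 == 0:
        # even exponent: base**exp == (base*base)**(exp//2), so exp halves
        return fast_power(base * base, exp // 2)
    # odd exponent: peel off one factor so the next call sees an even exp
    return base * fast_power(base, exp - 1)


print(fast_power(2, 10))  # 1024
print(fast_power(3, 4))   # 81
```

Because an odd step is always followed by an even step, the depth is at most about 2 * log2(exp), so even fast_power(2, 1000) stays around 20 frames deep.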

def power(base, exp):
    """Compute base ** exp recursively.
    exp must be a non-negative integer.
    Do not use ** or pow().
    """
    # TODO: implement
    pass
Expected Output
1
1024
81
125
Hints

Hint 1: Base case: any number raised to the power 0 is 1.

Hint 2: Recursive case: base * power(base, exp - 1) reduces exp by 1 each call.


Medium

#5 Recursive Binary Search (Medium)
Tags: binary-search, divide-and-conquer

Implement recursive binary search on a sorted list. Return the index of the target if found, or -1 if not found. Use the low and high parameters to track the current search range.

Test with [1, 3, 5, 7, 9, 11, 13, 15] and targets 7, 6, 1, 15.

Python
def binary_search(arr, target, low=0, high=None):
    if high is None:
        high = len(arr) - 1
    if low > high:
        return -1
    mid = (low + high) // 2
    if arr[mid] == target:
        return mid
    if arr[mid] < target:
        return binary_search(arr, target, mid + 1, high)
    return binary_search(arr, target, low, mid - 1)

nums = [1, 3, 5, 7, 9, 11, 13, 15]
print(binary_search(nums, 7))
print(binary_search(nums, 6))
print(binary_search(nums, 1))
print(binary_search(nums, 15))
Solution

Binary search halves the search space on each call, giving O(log n) time. For a list of 8 elements, at most 4 middle elements are ever examined (floor(log2 8) + 1 = 4), plus one final call when low > high, so the recursion never goes deeper than 5 frames. This makes it one of the safest uses of recursion in practice: even for a billion elements, the depth stays at about 31 frames, far below Python's default limit of 1000. The two base cases are target found (arr[mid] == target) and search space exhausted (low > high).
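To check the depth bound empirically, the sketch below is the same algorithm with one extra, hypothetical `depth` parameter that counts calls and is returned alongside the index.

```python
def binary_search_traced(arr, target, low=0, high=None, depth=1):
    """Same algorithm as binary_search, but returns (index, number_of_calls)."""
    if high is None:
        high = len(arr) - 1
    if low > high:
        return -1, depth  # search space exhausted
    mid = (low + high) // 2
    if arr[mid] == target:
        return mid, depth  # found on this call
    if arr[mid] < target:
        return binary_search_traced(arr, target, mid + 1, high, depth + 1)
    return binary_search_traced(arr, target, low, mid - 1, depth + 1)


nums = [1, 3, 5, 7, 9, 11, 13, 15]
print(binary_search_traced(nums, 7))   # (3, 1): the first midpoint is the target
print(binary_search_traced(nums, 16))  # (-1, 5): 4 probes plus the failing call
```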

def binary_search(arr, target, low=0, high=None):
    """Search for target in sorted arr.
    Return the index if found, -1 if not found.
    Use recursion — no loops.
    """
    # TODO: implement
    pass
Expected Output
3
-1
0
7
Hints

Hint 1: Base case: if low > high, the target is not in the array — return -1.

Hint 2: Compute mid = (low + high) // 2. If arr[mid] == target, return mid. Otherwise recurse on the correct half.

#6 Recursive Tree Traversal (Medium)
Tags: tree, inorder, preorder, postorder

Implement all three standard binary tree traversals: in-order, pre-order, and post-order. Each should return a list of node values.

Test with the tree:

      4
     / \
    2   6
   / \ / \
  1  3 5  7
Python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(node):
    if node is None:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)

def preorder(node):
    if node is None:
        return []
    return [node.val] + preorder(node.left) + preorder(node.right)

def postorder(node):
    if node is None:
        return []
    return postorder(node.left) + postorder(node.right) + [node.val]

tree = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)

print(inorder(tree))
print(preorder(tree))
print(postorder(tree))
Solution

The only difference between the three traversals is where [node.val] sits in the concatenation. In-order (left-root-right) produces sorted output for a binary search tree. Pre-order (root-left-right) is useful for serializing or copying a tree. Post-order (left-right-root) is useful for deletion or for computing values bottom-up, such as subtree sizes. The base case node is None covers both an empty tree and the missing children of leaf nodes. Recursion depth equals tree height: O(log n) for a balanced tree and O(n) for a degenerate (linear) tree.
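Each + in these solutions builds a brand-new list, so the total copying is proportional to n times the tree height (O(n^2) for a degenerate tree). A common alternative, sketched here with hypothetical `*_acc` names, passes one list down the recursion and appends to it, so each value is written exactly once. The traversal order is unchanged; only where the values are collected differs.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_acc(node, out=None):
    """In-order (left, root, right), appending into one shared list."""
    if out is None:
        out = []  # fresh list per top-level call (no mutable default)
    if node is not None:
        inorder_acc(node.left, out)
        out.append(node.val)
        inorder_acc(node.right, out)
    return out


def preorder_acc(node, out=None):
    """Pre-order (root, left, right), appending into one shared list."""
    if out is None:
        out = []
    if node is not None:
        out.append(node.val)
        preorder_acc(node.left, out)
        preorder_acc(node.right, out)
    return out


def postorder_acc(node, out=None):
    """Post-order (left, right, root), appending into one shared list."""
    if out is None:
        out = []
    if node is not None:
        postorder_acc(node.left, out)
        postorder_acc(node.right, out)
        out.append(node.val)
    return out


tree = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)
print(inorder_acc(tree))    # [1, 2, 3, 4, 5, 6, 7]
print(preorder_acc(tree))   # [4, 2, 1, 3, 6, 5, 7]
print(postorder_acc(tree))  # [1, 3, 2, 5, 7, 6, 4]
```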

class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(node):
    """Return a list of values in in-order: left, root, right."""
    # TODO: implement
    pass

def preorder(node):
    """Return a list of values in pre-order: root, left, right."""
    # TODO: implement
    pass

def postorder(node):
    """Return a list of values in post-order: left, right, root."""
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 4, 5, 6, 7]
[4, 2, 1, 3, 6, 5, 7]
[1, 3, 2, 5, 7, 6, 4]
Hints

Hint 1: Base case: if node is None, return an empty list.

Hint 2: In-order: recurse left, then current val, then recurse right. Pre-order: current val first. Post-order: current val last.

#7 Flatten Nested Lists (Medium)
Tags: nested-structures, isinstance

Write a function that takes an arbitrarily nested list and returns a flat list of all leaf values. The nesting depth can be anything.

Test with [1, [2, [3, 4], 5], 6], [[[[1]]], [[[2]]]], [], and [1, [2, [], [3]], 4].

Python
def flatten(lst):
    result = []
    for item in lst:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result

print(flatten([1, [2, [3, 4], 5], 6]))
print(flatten([[[[1]]], [[[2]]]]))
print(flatten([]))
print(flatten([1, [2, [], [3]], 4]))
Solution

This is a classic recursive pattern for nested data. Each element is either a leaf (base case — append directly) or a sub-list (recursive case — flatten it and extend). The recursion depth equals the maximum nesting level. For [1, [2, [3, [4]]]], the depth is 4. This pattern appears everywhere in production: flattening JSON responses, walking nested configs, and processing tree-structured data. Use isinstance(item, list) rather than type(item) == list to also handle subclasses.
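The same leaf-or-container split extends to JSON-like data that mixes dicts and lists. The sketch below uses a hypothetical generator, `iter_leaves`, that yields each leaf together with the path of keys and indices leading to it; it is one possible design, not a standard-library API. As with flatten, empty containers contribute no leaves.

```python
def iter_leaves(data, path=()):
    """Yield (path, value) for every leaf in nested dicts and lists.

    path is a tuple of the dict keys and list indices used to reach the leaf.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            yield from iter_leaves(value, path + (key,))  # recurse into values
    elif isinstance(data, list):
        for index, value in enumerate(data):
            yield from iter_leaves(value, path + (index,))
    else:
        yield path, data  # base case: a scalar leaf


config = {"db": {"host": "localhost", "ports": [5432, 5433]}, "debug": False}
for path, value in iter_leaves(config):
    print(path, value)
```

Dropping the paths (`[value for _, value in iter_leaves(data)]`) gives the same result as flatten for plain nested lists.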

def flatten(lst):
    """Recursively flatten an arbitrarily nested list.
    Return a flat list of all leaf values.
    Example: [1, [2, [3, 4], 5], 6] -> [1, 2, 3, 4, 5, 6]
    """
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 4, 5, 6]
[1, 2]
[]
[1, 2, 3, 4]
Hints

Hint 1: Iterate over each item. If it is a list (use isinstance), recurse into it. Otherwise, it is a leaf — append it.

Hint 2: Use extend() to add the flattened sub-list results to your accumulator.

#8 Recursive String Reversal (Medium)
Tags: string, recursive-decomposition

Reverse a string using recursion only. Do not use loops, [::-1] slicing, or the reversed() built-in.

Test with "abcd", "", "a", and "Hello".

Python
def reverse_string(s):
    if len(s) <= 1:
        return s
    return reverse_string(s[1:]) + s[0]

print(reverse_string("abcd"))
print(reverse_string(""))
print(reverse_string("a"))
print(reverse_string("Hello"))
Solution

This follows the same head-tail decomposition as the recursive list sum. For "abcd", the call chain is reverse_string("bcd") + "a", then reverse_string("cd") + "b" + "a", then reverse_string("d") + "c" + "b" + "a", which resolves to "d" + "c" + "b" + "a" = "dcba". Each call creates new strings through s[1:] and concatenation, making this O(n^2) in both time and memory. In production, use s[::-1], which is O(n). The exercise is still useful for building intuition about how recursion decomposes sequential data.
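A variant that avoids the quadratic copying works on a list of characters with two indices, swapping the outer pair and recursing inward. This is a minimal sketch with hypothetical names (`reverse_chars`, `reverse_string_swap`); it still uses no loops, slicing, or reversed(), does O(n) work, and recurses only about n / 2 levels deep.

```python
def reverse_chars(chars, lo, hi):
    """Reverse chars[lo..hi] in place by swapping the ends and recursing inward."""
    if lo >= hi:
        return  # base case: zero or one character left in the middle
    chars[lo], chars[hi] = chars[hi], chars[lo]
    reverse_chars(chars, lo + 1, hi - 1)


def reverse_string_swap(s):
    chars = list(s)  # strings are immutable, so work on a list copy
    reverse_chars(chars, 0, len(chars) - 1)
    return "".join(chars)


print(reverse_string_swap("Hello"))  # olleH
```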

def reverse_string(s):
    """Reverse a string recursively.
    No loops, no slicing with [::-1], no reversed().
    """
    # TODO: implement
    pass
Expected Output
dcba

a
olleH
Hints

Hint 1: Base case: an empty string or single character is already reversed.

Hint 2: Recursive case: move the first character to the end — reverse_string(s[1:]) + s[0].


Hard

#9 Fibonacci with Memoization (Hard)
Tags: fibonacci, memoization, performance

Implement both naive and memoized recursive Fibonacci. The naive version should be O(2^n) and the memoized version O(n) using a hand-rolled dictionary cache (not lru_cache).

Write a compare function that times both versions on n=30 to show the dramatic performance difference.

Python
import time

def fib_naive(n):
    if n <= 1:
        return n
    return fib_naive(n - 1) + fib_naive(n - 2)

def fib_memo(n, cache=None):
    if cache is None:
        cache = {}
    if n <= 1:
        return n
    if n in cache:
        return cache[n]
    result = fib_memo(n - 1, cache) + fib_memo(n - 2, cache)
    cache[n] = result
    return result

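def compare(n):
    """Time both versions on the same n, print the results, and return the timings."""
    # Naive
    start = time.perf_counter()
    result_naive = fib_naive(n)
    naive_time = time.perf_counter() - start

    # Memoized
    start = time.perf_counter()
    result_memo = fib_memo(n)
    memo_time = time.perf_counter() - start

    print(f"fib_naive({n}) = {result_naive}")
    print(f"fib_memo({n}) = {result_memo}")
    # Timings vary from machine to machine, so they are returned instead of printed.
    return naive_time, memo_time

naive_time, memo_time = compare(30)
print(f"fib_memo(200) = {fib_memo(200)}")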
Solution

Naive Fibonacci is exponential because every subproblem is recomputed from scratch: with F(0) = 0 and F(1) = 1, fib_naive(n) makes 2 * F(n+1) - 1 calls. fib_naive(30) therefore makes about 2.7 million calls, and fib_naive(50) would make about 40 billion. Memoization collapses the exponential call tree into a roughly linear chain: each distinct n is computed once and cached. That is why fib_memo(200) returns instantly; it performs exactly 199 computations (n = 2 through 200), and every other call is a base case or a cache hit. The cache trades O(n) space for an exponential reduction in time. In production, prefer functools.lru_cache (or functools.cache on Python 3.9+) over a hand-rolled dict: it keeps its internal state consistent when called from multiple threads and reports hit/miss statistics through cache_info().
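For comparison with the hand-rolled dict, here is a minimal sketch of the production option mentioned above: the same recursion wrapped in functools.lru_cache. The name `fib_cached` is only for illustration. Calls for n <= 1 also pass through the cache, so computing fib_cached(30) from a cold cache records one miss per distinct n from 0 to 30.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded cache; functools.cache is equivalent on 3.9+
def fib_cached(n):
    """Memoized Fibonacci using the standard-library decorator."""
    if n <= 1:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)


print(fib_cached(30))           # 832040
print(fib_cached.cache_info())  # CacheInfo(hits=..., misses=..., maxsize=None, currsize=...)
```

cache_clear() empties the cache and resets the statistics, which is useful when timing cold-cache runs.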

import time

def fib_naive(n):
    """Naive recursive Fibonacci — O(2^n)."""
    # TODO: implement
    pass

def fib_memo(n, cache=None):
    """Memoized recursive Fibonacci — O(n).
    Use a dict cache (no lru_cache decorator).
    """
    # TODO: implement with hand-rolled memoization
    pass

def compare(n):
    """Time both versions and print results."""
    # TODO: implement timing comparison
    pass
Expected Output
fib_naive(30) = 832040
fib_memo(30) = 832040
fib_memo(200) = 280571172992510140037611932413038677189525
Hints

Hint 1: Naive: if n <= 1 return n, else return fib(n-1) + fib(n-2). This recomputes overlapping subproblems.

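Hint 2: Memoized: check whether n is in the cache before computing. If it is not, compute the value and store it. Create the dict when cache is None and pass it down explicitly; a mutable default argument (cache={}) would be shared across every top-level call.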

#10 Recursive Tree Max Depth and Size (Hard)
Tags: tree, multiple-return, depth

Implement three recursive tree functions: max_depth, tree_size, and tree_sum. Each follows the same pattern — base case for None, then combine results from left and right subtrees.

Test with a balanced tree, a degenerate (linear) tree, and an empty tree.

Python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(node):
    if node is None:
        return 0
    return 1 + max(max_depth(node.left), max_depth(node.right))

def tree_size(node):
    if node is None:
        return 0
    return 1 + tree_size(node.left) + tree_size(node.right)

def tree_sum(node):
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)

# Balanced tree:       4
#                     / \
#                    2   6
#                   / \ / \
#                  1  3 5  7
balanced = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)

# Degenerate tree: 1 -> 2 -> 3 -> 4
degenerate = TreeNode(1, None,
    TreeNode(2, None,
        TreeNode(3, None,
            TreeNode(4))))

print(f"depth={max_depth(balanced)}, size={tree_size(balanced)}, sum={tree_sum(balanced)}")
print(f"depth={max_depth(degenerate)}, size={tree_size(degenerate)}, sum={tree_sum(degenerate)}")
print(f"depth={max_depth(None)}, size={tree_size(None)}, sum={tree_sum(None)}")
Solution

All three functions follow the same recursive template: base case returns 0 for None, recursive case combines the current node with results from both children. The only difference is the combining operation: max for depth, + for size and sum. This pattern generalizes to almost any tree computation — finding min/max values, checking balance, validating BST properties. For the degenerate tree (all right children), depth equals size (4), showing how unbalanced trees degrade recursion depth to O(n), making them vulnerable to RecursionError on large inputs.
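As one example of the "checking balance" case, the sketch below reuses the max_depth template but returns -1 as a sentinel once any subtree is unbalanced. It assumes the common height-balanced definition (at every node, the left and right subtree heights differ by at most 1); the names `balanced_height` and `is_balanced` are hypothetical.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def balanced_height(node):
    """Return the height of the subtree, or -1 if any subtree inside it is unbalanced."""
    if node is None:
        return 0  # same base case as max_depth
    left = balanced_height(node.left)
    if left == -1:
        return -1  # propagate failure without visiting the right side
    right = balanced_height(node.right)
    if right == -1:
        return -1
    if abs(left - right) > 1:
        return -1  # children's heights differ by more than 1
    return 1 + max(left, right)  # same combine step as max_depth


def is_balanced(node):
    return balanced_height(node) != -1


balanced = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(6, TreeNode(5), TreeNode(7))
)
degenerate = TreeNode(1, None, TreeNode(2, None, TreeNode(3, None, TreeNode(4))))
print(is_balanced(balanced), is_balanced(degenerate), is_balanced(None))  # True False True
```

Folding the height and the balance check into one return value keeps this a single O(n) pass instead of calling max_depth separately at every node.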

class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(node):
    """Return the maximum depth of the tree.
    An empty tree has depth 0. A single node has depth 1.
    """
    # TODO: implement
    pass

def tree_size(node):
    """Return the total number of nodes in the tree."""
    # TODO: implement
    pass

def tree_sum(node):
    """Return the sum of all node values."""
    # TODO: implement
    pass
Expected Output
depth=3, size=7, sum=28
depth=4, size=4, sum=10
depth=0, size=0, sum=0
Hints

Hint 1: Base case for all three: if node is None, return 0.

Hint 2: max_depth: return 1 + max(max_depth(left), max_depth(right)). tree_size: return 1 + size(left) + size(right). tree_sum: return val + sum(left) + sum(right).

#11 Recursive Merge Sort (Hard)
Tags: divide-and-conquer, sorting, merge-sort

Implement merge sort with two functions: merge_sort (recursive splitting) and merge (combining two sorted lists). This is the canonical divide-and-conquer algorithm.

Test with [5, 2, 8, 1, 9, 3], [], [42], and [3, 1, 4, 1, 3, 2].

Python
def merge_sort(arr):
    if len(arr) <= 1:
        return arr[:]  # copy, so even trivial inputs return a new list
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)

def merge(left, right):
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result

print(merge_sort([5, 2, 8, 1, 9, 3]))
print(merge_sort([]))
print(merge_sort([42]))
print(merge_sort([3, 1, 4, 1, 3, 2]))
Solution

Merge sort is O(n log n) in all cases: it always splits in half (about log2 n levels) and merges n elements in total at each level. The recursion depth is O(log n), so it is safe even for very large lists (a million elements is only about 20 levels deep). The merge function uses the two-pointer technique: compare the fronts of both sorted halves, take the smaller one, and advance that pointer. The <= comparison takes from the left half on ties, which makes the sort stable (equal elements keep their original order). The trade-off is memory: this version needs O(n) extra space at peak for slices and merged lists, whereas quicksort sorts in place with O(log n) expected stack space but degrades to O(n^2) time in the worst case.
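Stability only becomes visible when items compare equal on the sort key but are otherwise distinguishable. The sketch below adds a hypothetical `key` parameter (modeled loosely on the key argument of the built-in sorted()) and shows that records with equal keys keep their input order.

```python
def merge_sort_by(items, key):
    """Stable merge sort that orders items by key(item); always returns a new list."""
    if len(items) <= 1:
        return list(items)  # copy so the caller always gets a new list
    mid = len(items) // 2
    left = merge_sort_by(items[:mid], key)
    right = merge_sort_by(items[mid:], key)
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        # <= prefers the left half on ties, so equal keys keep their input order
        if key(left[i]) <= key(right[j]):
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged


orders = [("bob", 2), ("amy", 1), ("cat", 2), ("dan", 1)]
print(merge_sort_by(orders, key=lambda order: order[1]))
# [('amy', 1), ('dan', 1), ('bob', 2), ('cat', 2)]
```

Changing <= to < would still sort by key, but "dan" could then land before "amy", so the tie-breaking comparison is what carries the stability guarantee.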

def merge_sort(arr):
    """Sort arr using recursive merge sort.
    Return a new sorted list.
    """
    # TODO: implement
    pass

def merge(left, right):
    """Merge two sorted lists into one sorted list."""
    # TODO: implement
    pass
Expected Output
[1, 2, 3, 5, 8, 9]
[]
[42]
[1, 1, 2, 3, 3, 4]
Hints

Hint 1: Base case: a list of 0 or 1 elements is already sorted — return it.

Hint 2: Recursive case: split in half, sort each half recursively, then merge. The merge function uses two pointers to combine two sorted lists.

© 2026 EngineersOfAI. All rights reserved.