
Monte Carlo Tree Search for LLM Reasoning

The Game of Math

In 2016, DeepMind's AlphaGo defeated Lee Sedol, one of the greatest Go players in history. The achievement was remarkable not just for what it accomplished but for how it accomplished it. AlphaGo didn't just play by pattern-matching against millions of games. It searched. It used Monte Carlo Tree Search to explore possible future positions, guided by a neural network value function that estimated which board states were most promising.

The core idea: when faced with a problem with an astronomically large solution space (Go has more possible board states than atoms in the observable universe), intelligent search - exploring promising branches more deeply while pruning hopeless ones - can find good solutions that pure neural network inference cannot.

By 2023, several research groups noticed that mathematical reasoning has a similar structure. A hard math problem has a large search space of possible reasoning paths. Most paths lead to dead ends or wrong answers. A small number of paths lead to correct solutions. The model's neural network can estimate how promising a partial solution is. Why not search?

This is the motivation for applying MCTS to LLM reasoning.


Why This Exists - The Limits of Linear Generation

Standard LLM generation is linear. The model produces tokens left to right, one at a time, with no ability to revise earlier choices once made. For conversational tasks, this is fine. For complex multi-step reasoning where early mistakes propagate through the entire solution, it's a significant limitation.

Consider what happens when a model begins solving a hard algebra problem with an incorrect initial setup - perhaps choosing to approach it with substitution when factoring would work better. Every subsequent step inherits this poor choice. The model may complete a syntactically valid but mathematically wrong solution, with no ability to "go back" to step 1 and try a different approach.

MCTS addresses this by converting linear generation into a tree search. Rather than committing to a single reasoning path, the search explores multiple branches in a principled way and allocates more exploration to branches that appear promising. When an early branch turns out to be wrong, the search simply stops extending it and spends its budget elsewhere.


Classical MCTS - A Quick Recap

Before applying MCTS to language models, let's review the classical algorithm from game playing.

MCTS operates on a tree where:

  • Each node represents a game state (board position)
  • Each edge represents an action (move)
  • The root is the current state
  • Leaf nodes are either terminal (game over) or unexplored

The algorithm runs iterations of four phases:

Selection: Starting from the root, select child nodes using the UCT (Upper Confidence bounds applied to Trees) formula until reaching a leaf:

$$\text{UCT}(v) = \frac{w(v)}{n(v)} + c\sqrt{\frac{\ln N(v)}{n(v)}}$$

where $w(v)$ is the number of wins at node $v$, $n(v)$ is the number of visits to node $v$, $N(v)$ is the number of visits to the parent of $v$, and $c$ is an exploration constant.
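To make the exploration bonus concrete, here is a small sketch that evaluates the formula for three children of one node. The statistics and the value c = 1.4 are made up for illustration; they are not taken from any real search.

```python
import math


def uct(total_reward: float, visits: int, parent_visits: int, c: float = 1.4) -> float:
    """UCT score: average reward plus an exploration bonus."""
    if visits == 0:
        return math.inf  # unvisited children are always tried first
    exploitation = total_reward / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration


# Hypothetical statistics: name -> (total_reward, visits).
# The parent has been visited 20 + 8 + 2 = 30 times.
children = {"a": (12.0, 20), "b": (5.0, 8), "c": (1.0, 2)}
parent_visits = 30

scores = {name: uct(w, n, parent_visits) for name, (w, n) in children.items()}
selected = max(scores, key=scores.get)
print(scores, selected)
```

With these numbers the rarely visited child "c" is selected even though its average reward (0.5) is the lowest of the three, because its exploration term dominates. As a child accumulates visits, its bonus shrinks and the average reward takes over.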

Expansion: Add one or more child nodes to the selected leaf.

Simulation (Rollout): From the newly expanded node, play out to a terminal state using a fast (often random) policy.

Backpropagation: Update win/visit counts for all nodes along the path from the expanded node back to the root.

After many iterations, nodes that have been visited many times and have high win rates are identified as good moves.
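How the final move is read off the tree is a design choice. A common convention is to pick the most-visited child rather than the one with the highest average, since visit counts are less sensitive to a few lucky rollouts. The sketch below assumes that convention; `ChildStats` and `pick_final_move` are illustrative names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ChildStats:
    move: str
    visits: int
    total_reward: float


def pick_final_move(children: List[ChildStats]) -> str:
    """Return the most-visited move, breaking ties by average reward."""
    def key(child: ChildStats):
        mean = child.total_reward / child.visits if child.visits else 0.0
        return (child.visits, mean)

    return max(children, key=key).move


stats = [
    ChildStats("e4", visits=60, total_reward=33.0),  # mean 0.55
    ChildStats("d4", visits=35, total_reward=21.0),  # mean 0.60
    ChildStats("h4", visits=5, total_reward=4.0),    # mean 0.80, few visits
]
best_move = pick_final_move(stats)
print(best_move)
```

Here "h4" has the best average but only five visits, so the search has little evidence for it; the visit-count rule returns "e4".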


Adapting MCTS to Language Models

The mapping from game MCTS to language model reasoning:

| MCTS Concept | Game Playing | LLM Reasoning |
|---|---|---|
| Node | Board state | Partial solution (steps 0..k) |
| Action | Chess/Go move | Reasoning step (sentence/paragraph) |
| Terminal state | End of game | Complete solution with final answer |
| Value function | Win probability | PRM score / correctness probability |
| Rollout policy | Random play | Language model completion |
| Reward | +1 win, -1 loss | Correct/incorrect final answer |

The key differences from game MCTS:

The action space is effectively unbounded: in Go, there are ~250 legal moves per position. In language generation, there are ~50,000 possible next tokens at each position, and the number of distinct multi-token reasoning steps is effectively infinite. This makes naive MCTS impractical - you can't enumerate all possible next steps.

Solution: use the language model as a generative policy to propose a small number of candidate next steps ($k$ = 4–16 is typical), rather than enumerating all possible actions.
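Sampled candidates often repeat each other, which wastes tree width. The sketch below shows one way to collect k distinct proposals; `sample_step` is a hypothetical zero-argument callable standing in for a single sampled LLM continuation, and the case/whitespace normalization is only a crude notion of "same step".

```python
from typing import Callable, List


def propose_candidates(
    sample_step: Callable[[], str],
    k: int = 4,
    max_attempts: int = 32,
) -> List[str]:
    """Draw samples until k distinct steps are collected or attempts run out."""
    seen = set()
    candidates: List[str] = []
    for _ in range(max_attempts):
        if len(candidates) >= k:
            break
        step = sample_step().strip()
        key = " ".join(step.lower().split())  # normalize case and whitespace
        if step and key not in seen:
            seen.add(key)
            candidates.append(step)
    return candidates


# Stand-in sampler that replays a fixed list instead of calling a model.
_fake_samples = iter([
    "Factor x", "factor  X", "Substitute", "", "Complete the square", "Graph it",
])
proposals = propose_candidates(lambda: next(_fake_samples), k=3)
print(proposals)
```

A real system would probably want a semantic notion of duplication (for example, embedding similarity), but even string-level deduplication keeps the tree from spending visits on identical branches.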

Rollouts are expensive: in Go, random play is fast. LLM generation is expensive - each step requires a forward pass. A full rollout to answer takes hundreds to thousands of tokens.

Solution: use the PRM as a value function instead of rollouts. Rather than simulating to the end, estimate the value of a partial solution directly from the PRM score. This is analogous to AlphaGo's value network replacing expensive rollouts.
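A PRM typically returns one score per step, so the search still needs a rule for collapsing those scores into a single node value. The sketch below compares three simple choices; the function name and the example scores are illustrative, and which rule works best is an empirical question.

```python
import math
from typing import List


def aggregate_step_scores(step_scores: List[float], method: str = "min") -> float:
    """Collapse per-step PRM scores into one value for a partial solution."""
    if not step_scores:
        return 0.5  # neutral value when there is nothing to score yet
    if method == "min":
        return min(step_scores)        # a chain is as strong as its weakest step
    if method == "product":
        return math.prod(step_scores)  # treats steps as independent probabilities
    if method == "last":
        return step_scores[-1]         # trusts the PRM's view of the latest step
    raise ValueError(f"unknown method: {method}")


# Hypothetical scores: the third step looks suspicious.
example_scores = [0.9, 0.8, 0.3, 0.95]
for name in ("min", "product", "last"):
    print(name, round(aggregate_step_scores(example_scores, name), 4))
```

On this example "min" gives 0.3, "product" gives about 0.205, and "last" gives 0.95, so the choice of rule can change which branch the search prefers. The product rule also penalizes longer solutions, which may or may not be desirable.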

import math
import random
from dataclasses import dataclass, field
from typing import List, Optional, Callable


# eq=False keeps identity-based comparison; the generated __eq__ would
# otherwise recurse through the parent/children links.
@dataclass(eq=False)
class MCTSNode:
    """
    A node in the MCTS tree for LLM reasoning.

    Each node represents a partial solution (list of reasoning steps).
    """
    steps: List[str]  # Reasoning steps up to this point
    # repr=False avoids infinite recursion when printing a node
    parent: Optional['MCTSNode'] = field(default=None, repr=False)
    children: List['MCTSNode'] = field(default_factory=list)

    # MCTS statistics
    visits: int = 0
    value_sum: float = 0.0
    is_terminal: bool = False
    terminal_reward: float = 0.0

    # Cached PRM score for this node
    prm_score: Optional[float] = None

    @property
    def value(self) -> float:
        """Average value observed at this node."""
        if self.visits == 0:
            return 0.0
        return self.value_sum / self.visits

    def uct_score(
        self,
        exploration_constant: float = 1.0,
        parent_visits: Optional[int] = None
    ) -> float:
        """Upper Confidence Bound for Trees score."""
        if self.visits == 0:
            return float('inf')  # Unexplored nodes have infinite UCT

        if parent_visits is not None:
            N = parent_visits
        else:
            N = self.parent.visits if self.parent else 1
        N = max(N, 1)  # Guard against log(0)

        exploitation = self.value
        exploration = exploration_constant * math.sqrt(math.log(N) / self.visits)

        return exploitation + exploration


class LLMMCTSSearcher:
    """
    Monte Carlo Tree Search for LLM reasoning.

    Uses an LLM as the policy (to generate candidate next steps)
    and a PRM as the value function (to score partial solutions).
    """

    def __init__(
        self,
        policy_model,
        value_model,  # Process Reward Model
        outcome_verifier: Callable,
        n_candidates_per_expand: int = 4,
        exploration_constant: float = 1.0,
        max_depth: int = 12,
        max_iterations: int = 100,
    ):
        self.policy = policy_model
        self.value_fn = value_model
        self.verifier = outcome_verifier
        self.n_candidates = n_candidates_per_expand
        self.c = exploration_constant
        self.max_depth = max_depth
        self.max_iterations = max_iterations

    def search(self, problem: str) -> dict:
        """
        Run MCTS to find the best reasoning path for the problem.

        Returns the best complete solution found.
        """
        # Create root node (empty solution)
        root = MCTSNode(steps=[])

        best_solution = None
        best_reward = 0.0

        for _ in range(self.max_iterations):
            # Phase 1: Selection
            node = self._select(root)

            # Phase 2: Expansion
            if not node.is_terminal and len(node.steps) < self.max_depth:
                children = self._expand(problem, node)
                if children:
                    node = random.choice(children)

            # Phase 3: Evaluation
            if node.is_terminal:
                # Complete solution: score it with the outcome verifier
                reward = self.verifier(problem, node.steps)
                node.terminal_reward = reward
                if reward > best_reward:
                    best_reward = reward
                    best_solution = node.steps
                value = reward
            else:
                # Partial solution: use the PRM instead of a rollout
                value = self._evaluate(problem, node)

            # Phase 4: Backpropagation
            self._backpropagate(node, value)

        return {
            "best_solution": best_solution,
            "best_reward": best_reward,
            "total_iterations": self.max_iterations,
            "tree_size": self._count_nodes(root),
        }

    def _select(self, root: MCTSNode) -> MCTSNode:
        """
        Walk the tree from root, selecting children by UCT score,
        until we reach a leaf node.
        """
        node = root
        while node.children and not node.is_terminal:
            # Select child with highest UCT score
            node = max(
                node.children,
                key=lambda c: c.uct_score(self.c, node.visits)
            )
        return node

    def _expand(self, problem: str, node: MCTSNode) -> List[MCTSNode]:
        """
        Expand a leaf node by generating candidate next steps.
        """
        # Generate N candidate next steps using the policy.
        # sample_next_steps is the interface this code assumes the
        # policy object provides.
        candidate_steps = self.policy.sample_next_steps(
            problem=problem,
            completed_steps=node.steps,
            n=self.n_candidates,
            temperature=0.8,
        )

        children = []
        for step in candidate_steps:
            if step is None:  # None indicates "solution complete"
                child = MCTSNode(
                    steps=node.steps.copy(),
                    parent=node,
                    is_terminal=True,
                )
            else:
                child = MCTSNode(
                    steps=node.steps + [step],
                    parent=node,
                )

            node.children.append(child)
            children.append(child)

        return children

    def _evaluate(self, problem: str, node: MCTSNode) -> float:
        """
        Evaluate a node using the Process Reward Model.
        This replaces the expensive rollout simulation.
        """
        if not node.steps:
            return 0.5  # Neutral value for root

        if node.prm_score is not None:
            return node.prm_score

        # Get PRM scores for the partial solution (one score per step;
        # .item() assumes each score is a scalar tensor)
        step_scores = self.value_fn(problem, node.steps)
        # Use minimum step score as the node value
        value = min(s.item() for s in step_scores)

        node.prm_score = value
        return value

    def _backpropagate(self, node: MCTSNode, value: float):
        """
        Update visit counts and value sums from node back to root.
        """
        current = node
        while current is not None:
            current.visits += 1
            current.value_sum += value
            current = current.parent

    def _count_nodes(self, root: MCTSNode) -> int:
        """Count total nodes in the tree."""
        count = 1
        for child in root.children:
            count += self._count_nodes(child)
        return count

AlphaCode 2 - MCTS in Competitive Programming

DeepMind's AlphaCode 2 (AlphaCode Team, Google DeepMind, 2023) provides one of the most prominent real-world applications of MCTS-like search in LLM reasoning. AlphaCode 2 solved competitive programming problems by combining:

  1. A large code generation model as the policy
  2. A scoring model as the value function
  3. Massive sampling followed by clustering and re-ranking

While not strictly MCTS in the classical sense, it uses the same core idea: explore many candidate solutions, guide search with a value function, select the most promising candidates for further evaluation.

The results were dramatic: on Codeforces contests, AlphaCode 2 was estimated to perform better than roughly 85% of participants, whereas the original AlphaCode ranked at around the median of competitors (top 54%).

The key innovations:

  • Test-case filtering: generated solutions are executed against sample test cases; those that fail are eliminated early (cheap verification)
  • Solution clustering: semantically similar solutions are grouped; select one from each cluster to maximize coverage
  • Final scoring: among surviving candidates, a learned scorer selects the single best submission
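The solution clustering step in the list above needs some notion of "semantically similar". One simple approximation is to group candidates by their outputs on a handful of probe inputs, so that programs that behave identically land in the same cluster. In the sketch below, plain Python callables stand in for candidate programs; `cluster_by_behavior` and the probe inputs are illustrative assumptions, not AlphaCode 2's actual pipeline.

```python
from collections import defaultdict
from typing import Callable, List, Sequence


def cluster_by_behavior(
    programs: Sequence[Callable[[int], int]],
    probe_inputs: Sequence[int],
) -> List[List[int]]:
    """Group program indices by their outputs on the probe inputs."""
    clusters = defaultdict(list)
    for idx, program in enumerate(programs):
        signature = []
        for x in probe_inputs:
            try:
                signature.append(repr(program(x)))
            except Exception as exc:  # a crash is also a kind of behavior
                signature.append(f"error:{type(exc).__name__}")
        clusters[tuple(signature)].append(idx)
    # Largest clusters first: many samples agreeing is weak evidence of correctness
    return sorted(clusters.values(), key=len, reverse=True)


candidates = [
    lambda x: x * 2,
    lambda x: x + x,   # same behavior as the first candidate
    lambda x: x ** 2,
    lambda x: 1 // x,  # crashes on input 0
]
groups = cluster_by_behavior(candidates, probe_inputs=[0, 1, 3])
print(groups)
```

Picking one representative per cluster then avoids spending the submission budget on several copies of the same behavior.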

def alphacode_style_search(
    problem: str,
    code_model,
    scorer,
    test_cases: list,
    n_samples: int = 1000,
    n_final_solutions: int = 10,
    cluster_solutions: bool = True,
) -> dict:
    """
    AlphaCode 2-style search for competitive programming.

    Generate many solutions, filter with test cases,
    cluster, and select best.

    Note: execute_code and cluster_by_approach are helper functions
    that are assumed to exist; they are not defined here.
    """
    # Phase 1: Large-scale sampling
    candidate_solutions = []
    for _ in range(n_samples):
        code = code_model.generate(problem, temperature=0.8)
        candidate_solutions.append(code)

    print(f"Generated {len(candidate_solutions)} solutions")

    # Phase 2: Test case filtering
    # Eliminates obviously wrong solutions cheaply
    passing_solutions = []
    for code in candidate_solutions:
        all_pass = True
        for test_case in test_cases[:3]:  # Use first few test cases for speed
            result = execute_code(code, test_case.input)
            if result != test_case.expected_output:
                all_pass = False
                break

        if all_pass:
            passing_solutions.append(code)

    n_passed_tests = len(passing_solutions)  # Record before any fallback
    print(f"{n_passed_tests} solutions pass sample test cases")

    if not passing_solutions:
        # Fallback: nothing passed, so keep a slice of the unfiltered
        # pool and let the scorer rank it below
        passing_solutions = candidate_solutions[:100]

    # Phase 3: Optional clustering
    if cluster_solutions and len(passing_solutions) > n_final_solutions:
        # Cluster by semantic similarity (e.g., by approach used)
        # and select one representative per cluster
        clusters = cluster_by_approach(passing_solutions, n_clusters=n_final_solutions)
        representative_solutions = [
            max(cluster, key=lambda c: scorer.score(problem, c))
            for cluster in clusters
        ]
    else:
        representative_solutions = passing_solutions[:n_final_solutions * 3]

    # Phase 4: Final scoring and selection
    scored_solutions = [
        (sol, scorer.score(problem, sol))
        for sol in representative_solutions
    ]
    scored_solutions.sort(key=lambda x: x[1], reverse=True)

    best_solution, best_score = scored_solutions[0]

    return {
        "best_solution": best_solution,
        "score": best_score,
        "n_sampled": n_samples,
        "n_passed_tests": n_passed_tests,
        "n_final_candidates": len(representative_solutions),
    }
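The function above calls `execute_code` without defining it. A minimal sketch of such a helper, assuming the candidates are Python programs that read stdin and write stdout, could look like the following. The sentinel strings and the two-second timeout are arbitrary choices, and a subprocess is not a sandbox: untrusted generated code needs real isolation (containers, resource limits) before anything like this is used in earnest.

```python
import subprocess
import sys


def execute_code(code: str, stdin_text: str, timeout_s: float = 2.0) -> str:
    """Run Python source in a subprocess and return its stripped stdout.

    Returns "<timeout>" or "<error>" instead of raising, so a failing
    candidate simply does not match the expected output.
    """
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            input=stdin_text,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return "<timeout>"
    if proc.returncode != 0:
        return "<error>"
    return proc.stdout.strip()


doubled = execute_code("print(int(input()) * 2)", "21\n")
print(doubled)
```

Because the helper strips trailing whitespace, the expected outputs it is compared against should be normalized the same way.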

Tree-of-Thought - Simplified MCTS for Everyday Use

Yao et al. (2023) introduced Tree-of-Thought (ToT) as a simpler, more practical version of the MCTS idea. Instead of the full MCTS algorithm with UCT, backpropagation, and many iterations, Tree-of-Thought does:

  1. Generate $k$ candidate "thoughts" (partial solutions or next steps) at each step
  2. Evaluate each thought with either a value model or a self-evaluation prompt
  3. Select the top $b$ thoughts to expand further (breadth-first, up to some depth $d$)
  4. Return the best final solution

Tree-of-Thought is simpler to implement than full MCTS and works well for problems with a clear branching structure. It doesn't require a trained PRM - it can use self-evaluation (asking the model "how promising is this partial solution on a scale of 1–10?") as an approximation.

import re
from typing import Callable, List


def tree_of_thought(
    problem: str,
    model_fn: Callable,
    n_thoughts_per_step: int = 4,
    beam_width: int = 2,
    max_depth: int = 6,
    use_self_evaluation: bool = True,
) -> dict:
    """
    Tree-of-Thought search.

    At each depth level:
    1. Generate n_thoughts_per_step candidate next thoughts per beam member
    2. Evaluate each (via model self-evaluation or PRM)
    3. Keep top beam_width thoughts

    Args:
        problem: The problem to solve
        model_fn: LLM inference function
        n_thoughts_per_step: Number of candidate thoughts per beam state
        beam_width: Number of beams to maintain
        max_depth: Maximum reasoning depth
        use_self_evaluation: Use model self-evaluation if no PRM available

    Returns:
        Best solution found

    Note: generate_next_thought is a helper that is assumed to exist;
    it returns None when the solution is complete.
    """
    # Initialize beam with empty state
    beam = [{"steps": [], "score": 0.5}]
    depth_reached = 0

    for depth in range(max_depth):
        depth_reached = depth + 1
        all_candidates = []

        for state in beam:
            if state.get("is_complete", False):
                # Finished solutions are carried forward, not re-expanded
                all_candidates.append(state)
                continue

            # Generate candidate next thoughts
            for _ in range(n_thoughts_per_step):
                next_thought = generate_next_thought(
                    model_fn=model_fn,
                    problem=problem,
                    steps_so_far=state["steps"],
                    temperature=0.7,
                )

                if next_thought is None:
                    # Solution complete
                    all_candidates.append({
                        "steps": state["steps"],
                        "score": state["score"],
                        "is_complete": True,
                    })
                    continue

                new_steps = state["steps"] + [next_thought]

                if use_self_evaluation:
                    # Ask the model to evaluate the partial solution
                    eval_score = self_evaluate_partial_solution(
                        model_fn=model_fn,
                        problem=problem,
                        steps=new_steps,
                    )
                else:
                    eval_score = 0.5  # Placeholder - use PRM in practice

                all_candidates.append({
                    "steps": new_steps,
                    "score": eval_score,
                    "is_complete": False,
                })

        # Keep top beam_width candidates
        all_candidates.sort(key=lambda x: x["score"], reverse=True)
        beam = all_candidates[:beam_width]

        # Early termination if all beams have complete solutions
        if all(c.get("is_complete", False) for c in beam):
            break

    return {
        "best_solution": beam[0]["steps"],
        "best_score": beam[0]["score"],
        "depth_reached": depth_reached,
    }


def self_evaluate_partial_solution(
    model_fn: Callable,
    problem: str,
    steps: List[str],
) -> float:
    """
    Ask the model to evaluate how promising a partial solution is.
    Returns a score in [0, 1].
    """
    steps_text = "\n".join(f"Step {i+1}: {s}" for i, s in enumerate(steps))
    eval_prompt = f"""Problem: {problem}

Partial solution so far:
{steps_text}

Evaluate this partial solution:
- Is the reasoning correct so far?
- Does it seem to be making good progress toward a solution?
- Rate the likelihood that continuing from here leads to the correct answer.

Give a score from 0 to 10 where 0 = definitely wrong, 10 = definitely on track.
Respond with just the number."""

    response = model_fn(eval_prompt, temperature=0.1)

    # Parse the first number in the response and clamp it to [0, 1]
    match = re.search(r'\b(\d+(?:\.\d+)?)\b', response)
    if match:
        score = float(match.group(1))
        return min(max(score / 10.0, 0.0), 1.0)

    return 0.5  # Default if parsing fails

Compute vs. Quality Trade-offs

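MCTS for LLMs involves a fundamental compute-quality trade-off: more search iterations generally yield better solutions at higher cost. The multipliers and gains below are rough orders of magnitude and vary by model and task: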

| Strategy | Compute Cost | Quality Gain | Best For |
|---|---|---|---|
| Single pass | 1x | Baseline | Easy tasks, latency-critical |
| Best-of-N (N=8) | 8x | +15–20% | Medium tasks |
| Tree-of-Thought (d=4, b=2) | ~32x | +25–35% | Hard tasks, structured problems |
| Full MCTS (100 iterations) | ~200x | +40–60% | Hardest tasks, offline |
| o1-style extended thinking | Variable | State of art | Production reasoning models |

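The key insight from Snell et al. (2024): at a matched compute budget, PRM-guided tree search (beam search and lookahead search in their experiments) can outperform best-of-N, particularly on harder problems and at smaller budgets, because the search prunes bad branches early. The advantage is not uniform, though: which strategy is most compute-efficient depends on problem difficulty and on the size of the budget.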
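Before running a search it helps to budget the number of model calls it can make. The sketch below counts upper bounds on step-level calls for the search shapes discussed in this section, assuming one call per generated step and one call per evaluation. The function names are illustrative, and the counts are in calls rather than tokens, so turning them into cost multipliers relative to a single pass also requires the average length of a step and of a full solution.

```python
def best_of_n_calls(n: int) -> dict:
    """Best-of-N: n full-solution generations, each scored once."""
    return {"generations": n, "evaluations": n}


def tree_of_thought_calls(k: int, b: int, d: int) -> dict:
    """Beam-style ToT: the first level expands one root state, later levels expand b states."""
    generations = k + (d - 1) * b * k
    return {"generations": generations, "evaluations": generations}


def mcts_calls(iterations: int, k: int) -> dict:
    """MCTS with PRM evaluation: each iteration expands one leaf into k children
    and evaluates one of them."""
    return {"generations": iterations * k, "evaluations": iterations}


tot_budget = tree_of_thought_calls(k=4, b=2, d=4)
mcts_budget = mcts_calls(iterations=100, k=4)
print(best_of_n_calls(8), tot_budget, mcts_budget)
```

These are upper bounds: early termination, cached scores, and leaves that cannot be expanded all reduce the real count.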


Why MCTS Is Hard to Deploy in Production

Despite its theoretical elegance, MCTS-based reasoning has significant production challenges:

Latency is fundamentally high: MCTS requires multiple sequential rounds of generation. Even with parallelism within a round, the depth of the search creates irreducible latency. For a 10-deep search with 4 candidates per node, you're looking at 10 serial rounds of generation. At 2 seconds per round, that's 20 seconds minimum.
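The arithmetic in the paragraph above can be written down as a one-line estimate. This sketch assumes one node is expanded per round and that every model call takes the same time; `search_latency_seconds` is an illustrative name.

```python
def search_latency_seconds(
    depth: int,
    candidates_per_node: int,
    seconds_per_call: float,
    parallel_within_round: bool = True,
) -> float:
    """Lower bound on wall-clock time for a search that is `depth` rounds deep."""
    # With full parallelism a round costs one call; otherwise calls run back to back
    calls_per_round = 1 if parallel_within_round else candidates_per_node
    return depth * calls_per_round * seconds_per_call


parallel_latency = search_latency_seconds(10, 4, 2.0)
serial_latency = search_latency_seconds(10, 4, 2.0, parallel_within_round=False)
print(parallel_latency, serial_latency)
```

For the 10-deep, 4-candidate example this gives 20 seconds with perfect parallelism inside each round and 80 seconds if the candidates are generated one after another. Evaluation calls add to both figures.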

Memory requirements: maintaining the tree requires storing every partial solution generated during the search. The text itself is small, but if each node also keeps its KV cache to avoid recomputing the prefix, deep trees on hard problems can consume gigabytes.

No standardized API: current LLM APIs (OpenAI, Anthropic) don't expose the right primitives for MCTS - you can't easily ask for "generate from this specific partial context." You have to reconstruct context at each node from scratch.
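In practice, "reconstructing context" means rebuilding the full prompt for a node from the problem and the steps on the path to it. A minimal sketch, with an illustrative prompt layout and function name:

```python
from typing import List


def build_node_prompt(problem: str, steps: List[str]) -> str:
    """Rebuild the full prompt for a tree node from its path of steps."""
    lines = [f"Problem: {problem}", "", "Solution so far:"]
    lines += [f"Step {i}: {s}" for i, s in enumerate(steps, start=1)]
    lines.append(f"Step {len(steps) + 1}:")  # cue the model to write the next step
    return "\n".join(lines)


node_prompt = build_node_prompt(
    "Solve x^2 - 5x + 6 = 0.",
    ["Factor the quadratic as (x - 2)(x - 3).", "Set each factor to zero."],
)
print(node_prompt)
```

Sibling nodes share everything except their last step, so the prompts built this way have long common prefixes. Where the serving stack supports prefix caching, that overlap can recover part of the cost of resending the context.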

Value function calibration: the PRM's scores need to be well-calibrated across the entire problem distribution for MCTS to work well. A miscalibrated PRM can lead the search in wrong directions.
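One way to check calibration is to compare PRM scores against observed outcomes on held-out problems, for example with expected calibration error (ECE). The sketch below assumes each score is a probability in [0, 1] and each outcome is 1 if the partial solution eventually led to a correct answer and 0 otherwise; the bin count is an arbitrary choice.

```python
from typing import Sequence


def expected_calibration_error(
    scores: Sequence[float],
    outcomes: Sequence[int],
    n_bins: int = 5,
) -> float:
    """Weighted average gap between mean predicted score and observed accuracy per bin."""
    if len(scores) != len(outcomes) or not scores:
        raise ValueError("scores and outcomes must be non-empty and equal in length")

    bins = [[] for _ in range(n_bins)]
    for score, outcome in zip(scores, outcomes):
        idx = min(int(score * n_bins), n_bins - 1)  # a score of 1.0 goes in the last bin
        bins[idx].append((score, outcome))

    ece = 0.0
    for bucket in bins:
        if not bucket:
            continue
        mean_score = sum(s for s, _ in bucket) / len(bucket)
        accuracy = sum(y for _, y in bucket) / len(bucket)
        ece += (len(bucket) / len(scores)) * abs(mean_score - accuracy)
    return ece


# Hypothetical data: the PRM says 0.9 four times, but only three paths were correct.
ece_value = expected_calibration_error([0.9, 0.9, 0.9, 0.9], [1, 1, 1, 0])
print(round(ece_value, 3))
```

An ECE near zero means a score of 0.9 really does correspond to roughly 90% of such paths succeeding. A large value suggests recalibrating the PRM before trusting its scores inside UCT.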

The practical path for most production applications: use o1/o3 or DeepSeek-R1, which internalize this search during training and perform it efficiently during inference without exposing the search structure. Reserve explicit MCTS implementation for research contexts and offline batch workloads.


:::danger Common Mistake: Ignoring Cost in MCTS Benchmarks

MCTS papers often report accuracy at a fixed number of tokens or API calls. When comparing against best-of-N, ensure you're using the same total compute budget. A fair comparison: MCTS with 100 total nodes vs. best-of-N with 100 samples. Often the difference in accuracy is smaller than papers imply because the MCTS overhead (tree management, multiple rounds of generation) is not fully accounted for.

:::

:::warning Self-Evaluation Quality

Tree-of-Thought with self-evaluation (asking the model to rate its own partial solutions) is unreliable for the same reasons CoT faithfulness is limited. Models tend to rate their own reasoning highly regardless of actual correctness. Use a separate value model (PRM) for more reliable evaluation. Self-evaluation works acceptably for coarse filtering (eliminating clearly wrong branches) but not for fine-grained ranking.

:::

:::tip The Practical Recommendation

For most production applications, you do not need to implement MCTS yourself. Use o1/o3 or DeepSeek-R1, which have internalized this search at training time and run it efficiently at inference. Reserve explicit MCTS for: (1) problems where you need to verify intermediate steps (e.g., formal verification), (2) offline batch workloads where latency is not a constraint, (3) research experiments on new reasoning approaches.

:::


Interview Questions and Answers

Q1: How does MCTS apply to LLM reasoning? What are the key adaptations needed?

Classical MCTS was designed for game playing with discrete, enumerable actions (chess moves, Go moves). Adapting it to LLMs requires several key changes: (1) The action space is near-infinite (possible next tokens/sentences), so the policy model is used to propose a small number of candidate next steps rather than enumerating all actions. (2) Terminal state rewards (correct/incorrect answer) are expensive to compute, so a process reward model serves as a value function to evaluate partial solutions without completing them. (3) Node representations are natural language steps rather than board positions, requiring different similarity and expansion logic. (4) Rollouts (playing to terminal state) are replaced by PRM-based evaluation, as full rollouts are extremely expensive for language generation.

Q2: What is the difference between Tree-of-Thought and full MCTS?

Tree-of-Thought is a simplified, more practical version of MCTS. Both explore reasoning tree branches and use value estimates to guide exploration. Full MCTS uses UCT selection (balancing exploration and exploitation), proper backpropagation (updating statistics up the full path), and many iterations that can revisit and refine earlier branches. Tree-of-Thought is essentially beam search: at each depth level, generate $k$ candidates per beam state, evaluate each, keep the top $b$. It doesn't backpropagate, doesn't balance exploration/exploitation, and is breadth-first rather than iterative. Tree-of-Thought is much simpler to implement and more predictable in latency; full MCTS potentially finds better solutions but is harder to implement, harder to tune, and has variable latency.

Q3: What role does the PRM play in MCTS?

In AlphaGo-style MCTS, a neural value network estimates the win probability from any board position, replacing expensive random rollouts. In LLM MCTS, the PRM serves this role: it estimates the probability that a partial solution (reasoning up to step $k$) can be completed to a correct answer. This value estimate drives the UCT selection - nodes with high PRM scores are visited more often (exploitation), while low-visit nodes are still explored for potential surprises (exploration). Without a good PRM, MCTS degenerates to either random search (if you don't prune) or greedy search (if you always pick the top-scored branch). The quality of the PRM is the single biggest determinant of MCTS effectiveness.

Q4: Why is MCTS hard to deploy in production systems?

Five key challenges: (1) Latency - MCTS requires multiple serial rounds of generation; each round adds latency. For deep searches, total latency can be minutes. (2) Memory - the tree stores all partial solutions generated during search, which can require gigabytes for hard problems. (3) API limitations - standard LLM APIs don't expose primitives for efficient tree search; you have to reconstruct full context for each node. (4) Value function calibration - PRM scores must be well-calibrated; a miscalibrated PRM leads the search badly. (5) Hyperparameter sensitivity - the exploration constant $c$, number of candidates per expansion, and depth limit significantly affect both quality and cost and require careful tuning per task type.

Q5: Compare AlphaCode 2's approach to the theoretical MCTS algorithm.

AlphaCode 2 uses a pragmatic variant rather than strict MCTS. It does not maintain an explicit search tree with UCT selection and backpropagation. Instead, it does: massive parallel sampling (1,000+ solutions), cheap filtering (test case execution), clustering (reducing to a small representative set), and final scoring (selecting the best). This is closer to "best-of-N with filtering" than MCTS. The key insight: for code generation, test-case execution is a cheap verifier that eliminates most wrong solutions before expensive scoring. This filter is far more powerful than UCT-based pruning because it's based on actual execution rather than estimated value. The lesson: in domains with cheap verifiers, use them first; MCTS is most useful when verification is expensive.

Q6: How would you implement a practical reasoning search system for a production math tutoring application?

A production-appropriate implementation: (1) Use a distilled R1 model (DeepSeek-R1-Distill-Qwen-7B) as the policy - its thinking tokens implement learned internal search. (2) For problems where the R1 model has low confidence, augment with Tree-of-Thought: generate 4 initial approaches, use a simple heuristic (step count, presence of key mathematical terms) to score them, keep top 2. (3) For final answer verification: if you have a symbolic math solver (SymPy), use it to verify the numeric answer. (4) Latency budget: timeout at 30 seconds, return best solution found so far. (5) Caching: cache R1 outputs for identical or semantically similar problems using embedding similarity. This gives you most of the benefit of MCTS (multiple approaches, quality filtering) without the full overhead of a production MCTS implementation.
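The latency budget in point (4) amounts to an anytime loop: keep the best candidate seen so far and stop when the deadline passes. A minimal sketch, assuming a hypothetical `propose` callable that returns one (solution, score) pair per call:

```python
import time
from typing import Callable, Optional, Tuple


def anytime_search(
    propose: Callable[[], Tuple[str, float]],
    budget_seconds: float,
    max_attempts: int = 1000,
    good_enough: float = 1.0,
) -> Optional[Tuple[str, float]]:
    """Return the best (solution, score) found before the deadline, or None."""
    deadline = time.monotonic() + budget_seconds
    best: Optional[Tuple[str, float]] = None
    attempts = 0
    while attempts < max_attempts and time.monotonic() < deadline:
        solution, score = propose()
        attempts += 1
        if best is None or score > best[1]:
            best = (solution, score)
        if score >= good_enough:
            break  # verified or high-confidence answer: stop early
    return best


# Stand-in proposer that replays fixed candidates instead of calling a model.
_fake_candidates = iter([("x = 2", 0.2), ("x = 3", 0.9), ("x = 5", 0.5)])
best_found = anytime_search(lambda: next(_fake_candidates), budget_seconds=5.0, max_attempts=3)
print(best_found)
```

The deadline is only checked between attempts, so a single slow model call can still overrun the budget. A hard limit needs a timeout on the call itself.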


© 2026 EngineersOfAI. All rights reserved.