How to visualize a decision tree in Python
Learn how to visualize a decision tree in Python. Explore different methods, practical tips, real-world applications, and how to debug common errors.

Decision tree visualization is a key step for the interpretation of machine learning models. Python provides several libraries that simplify this process and make model logic transparent.
In this article, we'll cover several techniques for plotting decision trees using popular libraries. We'll also share practical tips for clear visuals, explore real-world applications, and offer advice for debugging your models effectively.
Using plot_tree function from scikit-learn
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)
plt.figure(figsize=(10, 6))
plot_tree(clf, filled=True)
plt.show()
```
Output: [A visualization of a decision tree with colored nodes showing the class distribution at each node. The tree has a depth of 2, with the root node splitting into two branches, each with child nodes.]
This example starts by training a DecisionTreeClassifier on the Iris dataset. We deliberately limit the tree's complexity with max_depth=2. This ensures the final visualization is clean and easy to interpret, as deeper trees can quickly become too dense to follow.
The plot_tree function then handles the visualization. Here’s what’s happening with its key parameters:
- It takes the trained model, `clf`, as its primary input to generate the plot.
- The `filled=True` argument colors the nodes to represent the majority class at each split, helping you grasp the decision process at a glance.
Basic customization techniques
Beyond the basics of plot_tree, you can unlock more powerful customizations and even interactive plots with libraries like Graphviz, matplotlib, and the specialized dtreeviz.
Enhancing trees with export_graphviz and Graphviz
```python
from sklearn.tree import export_graphviz
import graphviz

dot_data = export_graphviz(clf, out_file=None, feature_names=iris.feature_names,
                           class_names=iris.target_names, filled=True)
graph = graphviz.Source(dot_data)
graph
```
Output: [A more polished tree visualization with feature names (like "petal length", "petal width") at each decision node, class names at leaf nodes, and color-coded nodes representing different classes.]
For a more descriptive plot, you can use export_graphviz. This function translates your trained tree into the DOT graph description language. The graphviz library then uses this data to render a high-quality visual.
- The `feature_names` parameter labels the decision nodes with the actual feature names, like "petal width."
- `class_names` adds the target names to the leaf nodes, showing the final classification.
This combination adds crucial context that the basic plot lacks, making your tree’s logic much clearer.
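If you'd rather skip the Python `graphviz` binding, `export_graphviz` can also write the DOT text straight to disk via its `out_file` parameter. A minimal sketch — the filename `iris_tree.dot` is just an example, and rendering the file still requires the Graphviz `dot` tool:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)

# Write the DOT description directly to a file instead of returning a string
export_graphviz(clf, out_file="iris_tree.dot",
                feature_names=iris.feature_names,
                class_names=iris.target_names, filled=True)
```

You can then render it from a shell with `dot -Tpng iris_tree.dot -o iris_tree.png`, assuming Graphviz is installed on your system.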
Customizing tree appearance with matplotlib
```python
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names,
          filled=True, fontsize=10, precision=2)
plt.title("Decision Tree for Iris Dataset")
plt.show()
```
Output: [A customized decision tree visualization with the same structure as before, but with feature names displayed at decision nodes, class names at leaf nodes, and the title "Decision Tree for Iris Dataset".]
You can also fine-tune the tree’s appearance directly within `matplotlib`. The `plot_tree` function integrates smoothly with `matplotlib`'s customization options, giving you granular control over the final image.
This approach lets you adjust aesthetics with familiar arguments:
- `fontsize=10` controls the text size inside each node.
- `precision=2` rounds the values, making the plot cleaner.
You can also use standard `matplotlib` functions like plt.title() to add a title, providing context for your visualization.
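When the figure is destined for a document rather than the screen, you can save a high-resolution copy with `plt.savefig`. A minimal sketch — the `Agg` backend line and the filename are only for running this as a standalone script:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, safe for scripts and servers
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)

plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=iris.feature_names,
          class_names=iris.target_names, filled=True, fontsize=10)
plt.title("Decision Tree for Iris Dataset")
# Save a print-quality copy; call this before (or instead of) plt.show()
plt.savefig("iris_tree_hires.png", dpi=300, bbox_inches="tight")
plt.close()
```

Saving at `dpi=300` keeps the node text legible when the image is embedded in slides or reports.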
Interactive visualization with dtreeviz
```python
from dtreeviz.trees import dtreeviz

viz = dtreeviz(clf, iris.data, iris.target, target_name='variety',
               feature_names=iris.feature_names, class_names=list(iris.target_names))
viz.view()
```
Output: [A visualization showing the decision tree with histograms or density plots at each node displaying the distribution of samples. Nodes are color-coded, and feature splits are clearly shown with numerical thresholds.]
For a more intuitive approach, the dtreeviz library creates rich, data-aware visualizations. The dtreeviz() function shown here is the library's pre-2.0 API; releases from 2.0 onward expose the same functionality through dtreeviz.model() instead. Either way, it goes beyond static diagrams by plotting the actual distribution of data points at each node, giving you a much clearer picture of how the model separates classes at every split.
- It uses familiar parameters like `feature_names` and `class_names` for labeling.
- The final call to `viz.view()` renders the plot.
Advanced visualization techniques
Once you're comfortable with basic tree plots, you can tackle more advanced visualizations for feature importance, complex trees, and even entire random forests.
Displaying feature importance alongside the tree
```python
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
plot_tree(clf, ax=axes[0], filled=True)
axes[1].barh(range(len(iris.feature_names)), clf.feature_importances_)
axes[1].set_yticks(range(len(iris.feature_names)))
axes[1].set_yticklabels(iris.feature_names)
plt.show()
```
Output: [A side-by-side visualization with the decision tree on the left and a horizontal bar chart on the right showing the importance of each feature, with longer bars indicating higher importance.]
Visualizing feature importance alongside the tree gives you a more complete story of your model's logic. This code uses matplotlib's plt.subplots(1, 2) function to create a side-by-side layout for two plots.
- The left plot, drawn on `ax=axes[0]`, is the decision tree itself.
- The right plot is a horizontal bar chart that ranks how influential each feature is. This data comes directly from the classifier's `feature_importances_` attribute.
This combination helps you quickly connect the tree's splits to the most predictive features in your dataset.
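The bar chart reads more naturally when the bars are sorted by importance rather than by feature order. A small sketch using `np.argsort` — the output filename is just an example:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for script use
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)

# argsort returns indices from least to most important, so the
# longest bar ends up at the top of the horizontal chart
order = np.argsort(clf.feature_importances_)
plt.figure(figsize=(8, 4))
plt.barh(np.array(iris.feature_names)[order], clf.feature_importances_[order])
plt.xlabel("Importance")
plt.tight_layout()
plt.savefig("importances_sorted.png")
plt.close()
```

Sorting makes it immediately obvious which one or two features dominate the tree's splits.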
Visualization with pydotplus for complex trees
```python
import pydotplus
from sklearn.tree import export_graphviz

dot_data = export_graphviz(clf, out_file=None, feature_names=iris.feature_names,
                           class_names=iris.target_names, rounded=True, proportion=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('decision_tree_pydotplus.png')
```
Output: [The decision tree exported as a PNG file with rounded nodes. The visualization shows proportions of samples in each node and has a cleaner look, with rounded rectangles instead of sharp boxes.]
When dealing with larger trees, pydotplus provides a robust way to generate and save your visualizations. It uses the same export_graphviz function but lets you add parameters that improve clarity, especially for dense plots.
- `rounded=True` gives the nodes a cleaner look with rounded corners.
- `proportion=True` displays the percentage of samples for each class instead of raw counts.
The resulting DOT data is then processed by pydotplus.graph_from_dot_data(), and you can save the final image directly using graph.write_png().
Visualizing multiple trees in a random forest
```python
from sklearn.ensemble import RandomForestClassifier

forest = RandomForestClassifier(n_estimators=3, max_depth=2).fit(iris.data, iris.target)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, ax in enumerate(axes):
    plot_tree(forest.estimators_[i], filled=True, ax=ax)
    ax.set_title(f"Tree {i+1}")
plt.show()
```
Output: [A figure showing three decision trees side by side, each representing one tree from the random forest. Each tree has a different structure but the same maximum depth of 2, labeled "Tree 1", "Tree 2", and "Tree 3".]
A random forest is an ensemble of many decision trees, and visualizing them helps you understand how the model works under the hood. This code trains a RandomForestClassifier with three individual trees by setting n_estimators=3. It then iterates through each tree to plot them side by side.
- The `forest.estimators_` attribute gives you access to each trained tree in the forest.
- A `for` loop plots each tree on a separate `matplotlib` subplot using `plot_tree`.
- This shows how each tree learns slightly different patterns from the data.
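If you want to confirm numerically that the trees differ, each entry in `forest.estimators_` is an ordinary `DecisionTreeClassifier` you can inspect. A small sketch — `random_state=42` is added here only for reproducibility:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
forest = RandomForestClassifier(n_estimators=3, max_depth=2,
                                random_state=42).fit(iris.data, iris.target)

# Each fitted sub-tree exposes the same attributes as a standalone tree
for i, tree in enumerate(forest.estimators_):
    print(f"Tree {i + 1}: {tree.tree_.node_count} nodes, depth {tree.get_depth()}")
```

Because each tree is fit on a different bootstrap sample, node counts and split features typically vary from tree to tree even with identical hyperparameters.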
Move faster with Replit
Replit is an AI-powered development platform where all Python dependencies come pre-installed, so you can skip setup and start coding instantly. You can immediately apply the techniques you've just learned without worrying about environment configuration.
Instead of piecing together visualization techniques, you can use Agent 4 to build a complete application from a simple description. Describe the app you want to build, and the Agent will take it from idea to working product. For example:
- A model interpretability dashboard that visualizes a decision tree and its corresponding `feature_importances_` chart for any uploaded dataset.
- A pruning simulator where you can adjust parameters like `max_depth` and instantly see how the decision tree visualization changes.
- A random forest explorer that plots several individual trees from an ensemble model to show how they differ.
Simply describe your app, and Replit will write the code, test it, and fix issues automatically, all within your browser.
Common errors and challenges
Visualizing decision trees can present a few common challenges, but they're usually straightforward to resolve.
- Fixing overly complex trees with the `max_depth` parameter: A decision tree that grows too deep becomes a tangled mess, making it nearly impossible to read. This is also a classic sign of overfitting, where the model has learned the training data too well, including its noise. The simplest fix is to prune the tree by setting the `max_depth` parameter in your classifier. Limiting the depth creates a cleaner visualization and often a more generalizable model.
- Resolving feature name mismatch errors in `plot_tree`: You might see an error if the list of names you provide to the `feature_names` parameter doesn't match the number of features your model was trained on. To fix this, ensure the list of names has the exact same length as the number of features in your training data. For example, if your data has four features, your `feature_names` list must contain four strings.
- Fixing class name mismatches in tree visualization: Similarly, providing an incorrect number of class names can lead to errors or mislabeled leaf nodes, making your tree's predictions seem nonsensical. Make sure the list you pass to the `class_names` argument corresponds correctly to your target variable's unique classes. The order matters, so confirm it aligns with how the model internally sorts the classes.
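The two mismatch problems can be caught before you ever call `plot_tree`. A minimal sketch of the sanity checks, using attributes every fitted scikit-learn classifier exposes:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2).fit(iris.data, iris.target)

feature_names = list(iris.feature_names)
class_names = list(iris.target_names)

# One label per training feature, one label per learned class
assert len(feature_names) == clf.n_features_in_, "feature_names length mismatch"
assert len(class_names) == len(clf.classes_), "class_names length mismatch"
```

`n_features_in_` and `classes_` are set during `fit`, so these checks stay valid even if preprocessing changed your data along the way.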
Fixing overly complex trees with max_depth parameter
When a decision tree grows without restrictions, it can become overly complex. The resulting visualization is often a tangled mess that's impossible to read—a classic sign of overfitting, where the model has learned noise instead of the underlying signal.
The following code demonstrates this problem by training a DecisionTreeClassifier without setting a max_depth. The output shows just how unreadable the tree becomes.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
# Creating a tree without depth restriction
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plt.figure(figsize=(10, 6))
plot_tree(clf, filled=True)
plt.show()
```
Since no max_depth is set, the classifier keeps splitting until every leaf is pure, creating a tangled visual. The code below shows how a small change can produce a much clearer tree.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
# Limiting tree depth for better visualization
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
plt.figure(figsize=(10, 6))
plot_tree(clf, filled=True)
plt.show()
```
By setting max_depth=3 in the DecisionTreeClassifier, the second code block effectively prunes the tree. This prevents it from growing too deep and creating a tangled, unreadable plot. The resulting visualization isn't just cleaner; it represents a model that's less likely to be overfit. You'll want to adjust max_depth whenever your tree visualizations become too complex to follow, as it's a direct way to control model complexity and improve interpretability.
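Setting `max_depth` isn't the only pruning knob: scikit-learn's `DecisionTreeClassifier` also accepts a `ccp_alpha` parameter for minimal cost-complexity pruning, which removes branches that contribute little impurity reduction. A sketch — the alpha value `0.02` is only illustrative, and in practice you'd tune it (for example, via `cost_complexity_pruning_path`):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

unpruned = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
pruned = DecisionTreeClassifier(ccp_alpha=0.02,
                                random_state=0).fit(iris.data, iris.target)

# Cost-complexity pruning collapses weak branches, shrinking the tree
print("unpruned nodes:", unpruned.tree_.node_count)
print("pruned nodes:  ", pruned.tree_.node_count)
```

Unlike a hard depth cap, `ccp_alpha` prunes selectively, so deep branches that genuinely improve purity can survive while noisy ones are removed.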
Resolving feature name mismatch errors in plot_tree
A common error occurs when the list of names you pass to the feature_names parameter in plot_tree doesn't match the number of features in your data. This mismatch causes a ValueError, preventing the plot from rendering. The code below demonstrates this issue.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# Incorrect: providing too few feature names
custom_names = ["sepal length", "sepal width"]
plt.figure(figsize=(10, 6))
plot_tree(clf, feature_names=custom_names, filled=True)
plt.show()
```
The custom_names list only has two strings, while the model was trained on four features from the Iris dataset. This mismatch is what triggers the error. The following code demonstrates the correct approach.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# Correct: providing all four feature names
custom_names = ["sepal length", "sepal width", "petal length", "petal width"]
plt.figure(figsize=(10, 6))
plot_tree(clf, feature_names=custom_names, filled=True)
plt.show()
```
The fix is straightforward: the custom_names list now provides four names, perfectly matching the four features in the training data. This alignment resolves the error by giving the plot_tree function a label for every feature it needs to display. Always ensure the length of your feature_names list equals the number of features your model was trained on, especially after steps like feature selection where the count might change unexpectedly.
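One way to avoid the mismatch entirely is to derive the list from the fitted model's `n_features_in_` attribute instead of hard-coding it. A minimal sketch, using generated placeholder names as the fallback:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)

# The fitted model records how many features it saw during fit
n_features = clf.n_features_in_

# Use the dataset's names when the count matches; otherwise generate
# placeholder labels so the list length is always correct
names = list(iris.feature_names)
if len(names) != n_features:
    names = [f"feature {i}" for i in range(n_features)]

assert len(names) == n_features  # now always safe to pass to plot_tree
```

This pattern is especially handy after feature selection, where the surviving feature count can change without you noticing.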
Fixing class name mismatches in tree visualization
A mismatch in class names can also break your visualization. If the list you provide to class_names doesn't match the number of unique classes in your target variable, you'll get an error or mislabeled leaves. The following code demonstrates this issue.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# Class names don't match the number of classes (0, 1, 2)
wrong_class_names = ['Setosa', 'Versicolor']
plt.figure(figsize=(10, 6))
plot_tree(clf, class_names=wrong_class_names, filled=True)
plt.show()
```
The wrong_class_names list provides only two names, while the Iris dataset has three classes. This mismatch between the labels and the target data is what causes the error. The corrected code below shows the proper alignment.
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)
# Correct class names for all three classes
correct_class_names = ['Setosa', 'Versicolor', 'Virginica']
plt.figure(figsize=(10, 6))
plot_tree(clf, class_names=correct_class_names, filled=True)
plt.show()
```
The fix is to ensure the list passed to the class_names parameter contains a label for every unique class in your target data. By providing a list with all three Iris class names, the code now correctly aligns with the model's output. This kind of mismatch often occurs after preprocessing steps alter your target variable. Always confirm the number and order of your classes before plotting to ensure the leaf nodes are labeled correctly.
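A safer pattern is to build the list from the fitted model's `classes_` attribute, which fixes both the count and the order in one step. A minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3).fit(iris.data, iris.target)

# classes_ holds the learned labels in the model's internal (sorted) order,
# so indexing target_names with it keeps count and order aligned
class_names = [str(iris.target_names[c]) for c in clf.classes_]
print(class_names)  # ['setosa', 'versicolor', 'virginica']
```

Because the labels come from the model itself, this stays correct even if preprocessing dropped or remapped some classes.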
Real-world applications
Beyond debugging, these visualization techniques are essential for explaining model decisions in critical fields like medicine and feature analysis.
Explaining medical diagnoses with plot_tree
For high-stakes applications like medical diagnostics, plot_tree offers a straightforward way to visualize the specific factors driving a model's conclusion.
```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

cancer = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=3).fit(cancer.data, cancer.target)
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=cancer.feature_names, class_names=['Malignant', 'Benign'],
          filled=True, fontsize=12, max_depth=2)
plt.title("Decision Tree for Breast Cancer Diagnosis")
plt.show()
```
This code trains a DecisionTreeClassifier on the breast cancer dataset, limiting its growth to a max_depth of 3. It then uses plot_tree to create a visual that's easy to interpret.
- The `feature_names` and `class_names` parameters label the nodes, making the tree's logic transparent.
- Notice that while the model is trained with a depth of three, the visualization itself is capped at `max_depth=2`. This technique lets you display a simplified version of a more complex tree, making the top-level decisions easier to inspect.
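If a figure is more than you need — say, for a clinician-facing report or a log file — scikit-learn's `export_text` from the same module renders the identical decision rules as plain indented text:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text

cancer = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=2).fit(cancer.data, cancer.target)

# A plain-text rendering of the same decision rules shown in the plot
rules = export_text(clf, feature_names=list(cancer.feature_names))
print(rules)
```

The text output uses `|---` prefixes to show nesting, which copies cleanly into documentation or audit logs where images are awkward.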
Visualizing decision boundaries in feature space
Visualizing a tree's decision boundaries offers a different perspective, creating a map that shows exactly how the model carves up the feature space to separate different classes.
```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import matplotlib.pyplot as plt

iris = load_iris()
# Reduce iris data to 2D for visualization
pca = PCA(n_components=2)
iris_2d = pca.fit_transform(iris.data)
# Train a tree on the 2D projection
tree_clf = DecisionTreeClassifier(max_depth=3).fit(iris_2d, iris.target)
# Create a mesh grid to visualize decision boundaries
x_min, x_max = iris_2d[:, 0].min() - 1, iris_2d[:, 0].max() + 1
y_min, y_max = iris_2d[:, 1].min() - 1, iris_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))
Z = tree_clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(iris_2d[:, 0], iris_2d[:, 1], c=iris.target, s=20, edgecolor='k')
plt.title("Decision Boundaries of Tree Classifier")
plt.show()
```
This code creates a map of the model's decision-making process. Since you can't plot four dimensions, it first uses Principal Component Analysis (PCA) to compress the Iris dataset into two. A new decision tree is then trained on this simplified 2D data.
The visualization itself comes together in a few steps:
- A grid of points covering the plot area is created with `np.meshgrid`.
- The model predicts a class for every point on the grid.
- `plt.contourf` draws these predictions as colored regions, forming the decision boundaries.
- `plt.scatter` adds the actual data points on top.
Get started with Replit
Turn your knowledge into a real tool. Describe what you want to build to Replit Agent, like "a dashboard to plot a decision tree from a CSV" or "an app to see how max_depth changes a tree."
The Agent writes the code, tests for errors, and deploys your application for you. Start building with Replit.
Create and deploy websites, automations, internal tools, data pipelines and more in any programming language without setup, downloads or extra tools. All in a single cloud workspace with AI built in.
