How to make a decision tree in Python

Build decision trees in Python. Our guide covers methods, tips, real-world uses, and debugging common errors.

Published on: Mon, Apr 6, 2026
Updated on: Wed, Apr 8, 2026
The Replit Team

Decision trees are a core machine learning concept. They help you model complex decisions and predict outcomes. Python, with its rich libraries, provides an excellent environment to build and train them.

In this article, we'll walk through the essential techniques to build your own decision tree. You'll get practical tips, see real-world applications, and learn how to debug your models effectively.

Basic decision tree with scikit-learn

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=42)
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
print(f"Accuracy: {clf.score(X_test, y_test):.2f}")

Output:
Accuracy: 0.95

This code uses scikit-learn to build and test a simple decision tree on the classic Iris dataset. The process involves a few key steps:

  • Data Splitting: We use train_test_split to divide the data. This is crucial for testing the model's predictive power on data it hasn't seen before, which is how you detect overfitting.
  • Training: The fit method trains the DecisionTreeClassifier on the training portion of the data, teaching it to associate flower measurements with specific species.
  • Reproducibility: Setting random_state=42 in both train_test_split and DecisionTreeClassifier makes the data split and the model training deterministic, so you'll get the same accuracy (95% here) every time you run the code in the same environment.
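To look past the single accuracy number, a minimal sketch along these lines maps the predicted class indices back to species names and recomputes the score by hand. It reuses the same split as above; predicted_species and manual_accuracy are illustrative variable names, not part of scikit-learn.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=42
)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# predict returns one integer class index per test row
predictions = clf.predict(X_test)

# target_names is a NumPy array, so it can be indexed with the predictions
predicted_species = iris.target_names[predictions]
print(predicted_species[:5])

# Accuracy is the fraction of matching labels, which is what score reports
manual_accuracy = (predictions == y_test).mean()
print(f"Manual accuracy: {manual_accuracy:.2f}")
```

With the default test_size of 0.25, the test set holds 38 of the 150 flowers, so each misclassified flower moves the accuracy by roughly 2.6 percentage points.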

Fundamental decision tree techniques

Now that you have a working model, you can take it further by visualizing it with export_graphviz, validating it with cross-validation, and tuning it with GridSearchCV.

Visualizing the decision tree with export_graphviz

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(iris.data, iris.target)
dot_data = export_graphviz(clf, feature_names=iris.feature_names,
                           class_names=iris.target_names, filled=True)
graph = graphviz.Source(dot_data)

Output:
[Graphviz visualization of the decision tree structure]

Visualizing your decision tree is a great way to understand its internal logic. The export_graphviz function converts your trained model into DOT, a graph description language, which the graphviz library then renders as an image. Rendering requires the Graphviz system binaries in addition to the graphviz Python package. Notice we set max_depth=3 to keep the tree simple and readable.

  • The feature_names and class_names arguments label the nodes, clarifying the decision criteria and outcomes at each step.
  • Setting filled=True colors each node by its majority class, with more saturated colors for purer nodes.
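If Graphviz isn't available in your environment, a text rendering can serve as a quick check of the same tree. The sketch below swaps in scikit-learn's export_text, which is a different function from export_graphviz and produces plain indented rules instead of an image.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(iris.data, iris.target)

# export_text needs no Graphviz install; it returns the rules as a string
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line is one split or leaf, so you can read a path from the root down to a predicted class.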

Using cross-validation for more reliable evaluation

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42)
cv_scores = cross_val_score(clf, iris.data, iris.target, cv=5)
print(f"Cross-validation scores: {cv_scores}")
print(f"Average accuracy: {cv_scores.mean():.2f}")

Output:
Cross-validation scores: [0.96666667 0.96666667 0.9 0.93333333 1. ]
Average accuracy: 0.95

While a single train-test split is useful, cross-validation provides a more robust evaluation of your model's performance. The cross_val_score function handles this for you automatically.

  • It splits the data into a specified number of "folds"—here, five, because we set cv=5.
  • The model is then trained and tested five times. In each run, a different fold serves as the test set while the remaining four are used for training.
  • By averaging the scores from all five runs with cv_scores.mean(), you get a more stable and trustworthy measure of accuracy.
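To make the fold mechanics concrete, here is a rough sketch of the loop that cross_val_score runs for you. It assumes the default behavior for a classifier with an integer cv, which is stratified folds without shuffling; fold_model and manual_scores are illustrative names.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=42)

# For a classifier and an integer cv, scikit-learn uses stratified folds
folds = StratifiedKFold(n_splits=5)

manual_scores = []
for train_idx, test_idx in folds.split(X, y):
    fold_model = clone(clf)  # fresh, unfitted copy for every fold
    fold_model.fit(X[train_idx], y[train_idx])
    manual_scores.append(fold_model.score(X[test_idx], y[test_idx]))

manual_scores = np.array(manual_scores)
print(f"Manual fold scores: {manual_scores}")
print(f"Mean: {manual_scores.mean():.2f}, std: {manual_scores.std():.2f}")
```

Reporting the standard deviation next to the mean shows how much the score depends on which fold was held out.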

Optimizing with GridSearchCV for hyperparameter tuning

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

iris = load_iris()
param_grid = {'max_depth': [3, 5, 10], 'min_samples_split': [2, 5, 10]}
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(iris.data, iris.target)
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best accuracy: {grid_search.best_score_:.2f}")

Output:
Best parameters: {'max_depth': 3, 'min_samples_split': 2}
Best accuracy: 0.97

Fine-tuning your model's hyperparameters is key to improving its performance. GridSearchCV automates this process by testing every combination of parameters you provide in a param_grid. It uses cross-validation to find which settings work best.

  • The param_grid dictionary defines the search space—here, you're testing different values for max_depth and min_samples_split.
  • After running fit, you can access the optimal settings with best_params_. The best_score_ attribute holds the mean cross-validated accuracy of that combination, which is 97% in this case.
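Beyond the single best combination, it can help to see how every candidate scored. The following sketch reads the cv_results_ dictionary and the refit best_estimator_ from the same search; the loop variable names are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
param_grid = {'max_depth': [3, 5, 10], 'min_samples_split': [2, 5, 10]}
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(iris.data, iris.target)

# cv_results_ holds one entry per parameter combination (3 x 3 = 9 here)
results = grid_search.cv_results_
for params, mean, std in zip(results['params'],
                             results['mean_test_score'],
                             results['std_test_score']):
    print(f"{params}: {mean:.3f} (+/- {std:.3f})")

# With the default refit=True, best_estimator_ is retrained on all the data
best_tree = grid_search.best_estimator_
print(f"Depth of the refit tree: {best_tree.get_depth()}")
```

When several combinations score within a standard deviation of each other, the simpler tree is usually the safer pick.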

Advanced decision tree implementations

With your model tuned, you can now push its performance further by analyzing feature importance, handling imbalanced data, and using more powerful ensemble methods.

Analyzing feature importance for better understanding

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import numpy as np

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42)
clf.fit(iris.data, iris.target)
importances = clf.feature_importances_
indices = np.argsort(importances)[::-1]
for i in range(len(iris.feature_names)):
    print(f"{iris.feature_names[indices[i]]}: {importances[indices[i]]:.4f}")

Output:
petal width (cm): 0.5425
petal length (cm): 0.4575
sepal length (cm): 0.0000
sepal width (cm): 0.0000

Understanding which features your model values most is crucial for interpretation. After training, the DecisionTreeClassifier stores this information in its feature_importances_ attribute. Each feature gets a score equal to its normalized share of the total impurity reduction across all splits, so the scores sum to 1. The code then sorts these scores to rank the features from most to least influential.

  • For the Iris dataset, the model relies entirely on petal width (cm) and petal length (cm). The sepal measurements have an importance of zero, meaning they weren't used in any decision splits.
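As a compact variation on the ranking code, this sketch pairs names and scores with zip and checks that the scores add up to 1. The names ranked and total are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# Pair each feature name with its score and sort from high to low
ranked = sorted(
    zip(iris.feature_names, clf.feature_importances_),
    key=lambda pair: pair[1],
    reverse=True,
)
for name, score in ranked:
    print(f"{name}: {score:.4f}")

# Importances are normalized, so they should sum to 1
total = clf.feature_importances_.sum()
print(f"Sum of importances: {total:.4f}")
```

Because the scores are shares of a whole, a low score means a feature was rarely useful for splitting in this particular tree, not that it carries no information.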

Handling imbalanced data with SMOTE

from sklearn.datasets import make_classification
from imblearn.over_sampling import SMOTE
import numpy as np

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)
smote = SMOTE(random_state=42)
X_balanced, y_balanced = smote.fit_resample(X, y)
print(f"Original class distribution: {np.bincount(y)}")
print(f"Balanced class distribution: {np.bincount(y_balanced)}")

Output:
Original class distribution: [900 100]
Balanced class distribution: [900 900]

Decision trees can struggle with imbalanced data, where one class vastly outnumbers another. This biases the model toward the majority class. SMOTE (Synthetic Minority Over-sampling Technique) addresses this by creating new, synthetic samples for the minority class instead of just duplicating them.

  • The code uses make_classification to generate a dataset with a 900-to-100 class imbalance.
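  • Applying smote.fit_resample balances the dataset by oversampling the minority class until both classes have the same number of samples, which gives the tree more minority examples to learn from. In a real project, resample only the training split after calling train_test_split. Otherwise, synthetic points derived from test samples leak into training and inflate your scores.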
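If you'd rather not add the imbalanced-learn dependency, scikit-learn's own class_weight parameter is another way to counter imbalance. Note that this is a different technique from SMOTE: it reweights existing samples instead of creating synthetic ones. The sketch below compares minority-class recall with and without weighting; max_depth=3 is an arbitrary choice for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)

# stratify=y keeps the same class ratio in the training and test splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42
)
print(f"Training class distribution: {np.bincount(y_train)}")

recalls = {}
for weighting in (None, 'balanced'):
    clf = DecisionTreeClassifier(max_depth=3, class_weight=weighting, random_state=42)
    clf.fit(X_train, y_train)
    # recall_score defaults to the positive class, which is the minority (label 1)
    recalls[weighting] = recall_score(y_test, clf.predict(X_test))
    print(f"class_weight={weighting}: minority recall = {recalls[weighting]:.2f}")
```

Whichever approach you use, judge it on minority-class recall or a similar metric, since plain accuracy can look high even when the minority class is ignored.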

Improving performance with RandomForestClassifier

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=42)
dt = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)
print(f"Decision Tree accuracy: {dt.score(X_test, y_test):.2f}")
print(f"Random Forest accuracy: {rf.score(X_test, y_test):.2f}")

Output:
Decision Tree accuracy: 0.95
Random Forest accuracy: 0.97

A single decision tree is good, but a Random Forest is often better. It’s an ensemble method, meaning it builds many decision trees—100 in this case, set by n_estimators=100—and aggregates their predictions. This collective approach makes the model more robust and less prone to errors from any single tree.

  • The code directly compares a single DecisionTreeClassifier with a RandomForestClassifier on the same data.
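  • By averaging the results from all its trees, the Random Forest reduces overfitting and often improves predictive power. Here accuracy rises from 95% to 97%, but the test set holds only 38 flowers, so that gap is a single sample. Treat it as a hint, not proof.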
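Since one train-test split leaves so few test samples, a steadier comparison might use cross-validation for both models. This sketch reuses cross_val_score from earlier in the article; the models and mean_scores dictionaries are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
models = {
    'Decision Tree': DecisionTreeClassifier(random_state=42),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
}

mean_scores = {}
for name, model in models.items():
    scores = cross_val_score(model, iris.data, iris.target, cv=5)
    mean_scores[name] = scores.mean()
    print(f"{name}: {scores.mean():.3f} (+/- {scores.std():.3f})")

# A fitted forest exposes its individual trees through estimators_
rf = models['Random Forest'].fit(iris.data, iris.target)
print(f"Trees in the forest: {len(rf.estimators_)}")
```

On a small, easy dataset like Iris the two models may score almost the same; the forest's advantage tends to show on larger, noisier data.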

Move faster with Replit

Replit is an AI-powered development platform that comes with all Python dependencies pre-installed, so you can skip setup and start coding instantly. You don't need to worry about managing environments or installations.

The techniques in this article are powerful building blocks. With Agent 4, you can move from piecing them together to building complete applications. It takes your description and handles the code, databases, APIs, and deployment.

  • A customer churn predictor that uses feature_importances_ to identify the most significant reasons customers leave.
  • A fraud detection system that leverages SMOTE to effectively train a model on rare but critical fraudulent activities.
  • An automated model tuner that runs GridSearchCV to find the optimal max_depth and min_samples_split for your classifier.

Simply describe your app, and Replit will write the code, test it, and fix issues automatically, all within your browser.

Common errors and challenges

Building decision trees often involves a few common hurdles, but they're straightforward to overcome once you know what to look for.

  • Fixing overfitting with max_depth: Overfitting happens when your model learns the training data too well, including its noise, and then fails to generalize to new data. You can prevent this by tuning the max_depth hyperparameter, which limits the tree's complexity and forces it to capture only the most significant patterns.
  • Handling categorical features with OneHotEncoder: Decision trees in scikit-learn require numerical inputs, so they can't process text labels directly. Use OneHotEncoder to convert categorical data into a binary format (columns of 0s and 1s) that the model can understand.
  • Troubleshooting prediction shape mismatch with reshape: You'll often see errors when predicting a single sample because its array shape is wrong. This is because the model expects a 2D array, but a single sample is often a 1D array. Use the reshape(1, -1) method to adjust the data's dimensions to the format your model expects.

Fixing overfitting with proper max_depth parameter

Overfitting is a classic trap where a model memorizes training data instead of learning general patterns. This results in high training accuracy but poor performance on unseen test data. The code below demonstrates this problem with an unconstrained decision tree.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=42)
# No limit on tree depth - will likely overfit
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
print(f"Training accuracy: {clf.score(X_train, y_train):.2f}")
print(f"Test accuracy: {clf.score(X_test, y_test):.2f}")

Because the DecisionTreeClassifier is unconstrained, it achieves perfect training accuracy but lower test accuracy. This gap is a clear sign of overfitting. The code below shows how a simple adjustment to the model's parameters closes this performance gap.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=42)
# Set max_depth to prevent overfitting
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)
print(f"Training accuracy: {clf.score(X_train, y_train):.2f}")
print(f"Test accuracy: {clf.score(X_test, y_test):.2f}")

By setting max_depth=3, you limit the tree's complexity, forcing the model to focus on broader patterns instead of memorizing training data. The result is that training and test accuracies become much more aligned, showing the model now generalizes better. Keep an eye out for a large gap between training and test scores—it’s a classic sign that you need to rein in your model's complexity with a parameter like max_depth.
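To choose a depth instead of guessing one, you could sweep a few values and watch the gap between training and test accuracy. The sketch below is one way to do that on the same split; the depth values and the depth_report list are illustrative choices.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=42
)

depth_report = []
for depth in (1, 2, 3, 5, None):  # None means no depth limit
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    clf.fit(X_train, y_train)
    train_acc = clf.score(X_train, y_train)
    test_acc = clf.score(X_test, y_test)
    # get_depth reports how deep the tree actually grew
    depth_report.append((depth, clf.get_depth(), train_acc, test_acc))
    print(f"max_depth={depth}: actual depth={clf.get_depth()}, "
          f"train={train_acc:.2f}, test={test_acc:.2f}")
```

A good setting is typically the smallest depth after which test accuracy stops improving. For a more reliable choice, combine this sweep with cross-validation as GridSearchCV does.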

Handling categorical features with OneHotEncoder

Decision trees in scikit-learn work with numbers, not text. This means you can't directly use categorical features like 'red' or 'blue' in your model. Attempting to fit the model with this data will raise an error, as you'll see below.

from sklearn.tree import DecisionTreeClassifier
import pandas as pd

# Create dataset with categorical features
data = pd.DataFrame({
    'feature1': [1.2, 0.5, 3.1, 2.0],
    'feature2': ['red', 'blue', 'red', 'green'],
    'target': [0, 1, 0, 1]
})

X = data[['feature1', 'feature2']]
y = data['target']

# This will fail because 'feature2' is categorical
clf = DecisionTreeClassifier()
clf.fit(X, y)

The fit method can't process text values like 'red' and 'blue' in the feature2 column, which triggers an error. The following code demonstrates how to properly prepare this data for the model before training.

from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
import pandas as pd

# Create dataset with categorical features
data = pd.DataFrame({
    'feature1': [1.2, 0.5, 3.1, 2.0],
    'feature2': ['red', 'blue', 'red', 'green'],
    'target': [0, 1, 0, 1]
})

# Properly encode categorical features
# Selecting the column by name is clearer than the positional index [1]
preprocessor = ColumnTransformer(
    transformers=[('cat', OneHotEncoder(), ['feature2'])],
    remainder='passthrough'
)
X_encoded = preprocessor.fit_transform(data[['feature1', 'feature2']])
y = data['target']

clf = DecisionTreeClassifier()
clf.fit(X_encoded, y)

To fix this, you'll need to convert the text data into a numerical format. The OneHotEncoder handles this by creating new binary columns for each category. You can use a ColumnTransformer to apply this encoding only to the categorical feature, while remainder='passthrough' keeps the numerical columns as they are. Your model can then be trained on this newly encoded data without any errors. This is a common step whenever your dataset contains non-numeric features.
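One practical follow-up is making sure new data goes through the same encoding at prediction time. A possible sketch wraps the encoder and the tree in a scikit-learn Pipeline. The step names 'encode' and 'tree', the new_rows data, and the handle_unknown='ignore' setting are assumptions added for illustration.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

data = pd.DataFrame({
    'feature1': [1.2, 0.5, 3.1, 2.0],
    'feature2': ['red', 'blue', 'red', 'green'],
    'target': [0, 1, 0, 1]
})

preprocessor = ColumnTransformer(
    transformers=[('cat', OneHotEncoder(handle_unknown='ignore'), ['feature2'])],
    remainder='passthrough'
)
model = Pipeline([
    ('encode', preprocessor),
    ('tree', DecisionTreeClassifier(random_state=42)),
])
model.fit(data[['feature1', 'feature2']], data['target'])

# 'yellow' never appeared in training; handle_unknown='ignore' encodes it as all zeros
new_rows = pd.DataFrame({'feature1': [1.0, 2.5], 'feature2': ['blue', 'yellow']})
predictions = model.predict(new_rows)
print(predictions)
```

Because the pipeline stores the fitted encoder, you call predict on raw rows and never have to re-encode them by hand.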

Troubleshooting prediction shape mismatch with reshape

A common error you'll encounter is a shape mismatch when making a single prediction. scikit-learn models expect a 2D array of samples, but a single data point is often a 1D array. This mismatch will cause the predict method to fail.

The code below shows exactly what happens when you pass a single sample with the wrong dimensions.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42)
clf.fit(iris.data, iris.target)

# Trying to predict with incorrect shape
new_sample = [5.1, 3.5, 1.4, 0.2]
prediction = clf.predict(new_sample) # This will fail
print(f"Prediction: {prediction}")

The predict method is designed to process a batch of samples. Passing a single sample directly, as with new_sample, creates a structural mismatch and triggers an error. The following code shows how to format the input correctly.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import numpy as np

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42)
clf.fit(iris.data, iris.target)

# Reshape data for prediction - each sample needs to be 2D
new_sample = np.array([5.1, 3.5, 1.4, 0.2]).reshape(1, -1)
prediction = clf.predict(new_sample)
print(f"Prediction: {prediction}")

The fix is simple: you need to reshape your single sample into a 2D array. Using NumPy's reshape(1, -1) method turns your single data point into an array with one row, matching the structure the predict method expects. The 1 means one sample, and the -1 tells NumPy to infer the number of features from the data. Keep an eye out for this error whenever you're predicting a single observation after training on a full dataset. It's a common and easy-to-fix mismatch.
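reshape is not the only way to get a 2D input. The sketch below shows two equivalent options, a nested list and np.atleast_2d, and confirms they give the same prediction. The variable names and the second row in the batch are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

sample = [5.1, 3.5, 1.4, 0.2]

as_nested_list = [sample]                      # a list containing one row
as_reshaped = np.array(sample).reshape(1, -1)  # shape (1, 4)
as_2d = np.atleast_2d(sample)                  # also shape (1, 4)

results = [clf.predict(x)[0] for x in (as_nested_list, as_reshaped, as_2d)]
print(results)
print(f"Predicted species: {iris.target_names[results[0]]}")

# A batch is simply more rows in the same 2D layout
batch = np.array([sample, [6.7, 3.0, 5.2, 2.3]])
batch_predictions = clf.predict(batch)
print(batch_predictions)
```

Whichever form you pick, the rule is the same: rows are samples and columns are features.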

Real-world applications

With the technical challenges solved, you can now apply these models to predict practical outcomes like customer churn and loan defaults.

Using decision trees for customer churn prediction

You can train a decision tree on customer data, like account tenure and monthly charges, to predict which users are likely to churn.

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Sample customer data (1=churned, 0=retained)
customers = pd.DataFrame({
    'tenure_months': [2, 7, 5, 1, 8, 2, 4],
    'monthly_charge': [65, 45, 85, 95, 35, 75, 50],
    'churn': [1, 0, 1, 1, 0, 1, 0]
})
model = DecisionTreeClassifier(max_depth=2)
model.fit(customers[['tenure_months', 'monthly_charge']], customers['churn'])
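# Use a DataFrame with the training column names to avoid a feature-name warning
new_customer = pd.DataFrame({'tenure_months': [3], 'monthly_charge': [60]})
print(f"New customer churn prediction: {model.predict(new_customer)}")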

This code builds a model to forecast customer churn using a sample dataset created with a pandas DataFrame. The model learns from features like tenure_months and monthly_charge to predict whether a customer will stay or leave.

  • A DecisionTreeClassifier is initialized with max_depth=2 to keep the model simple and interpretable.
  • The fit method trains the model on the historical data, teaching it the patterns associated with churn.
  • Finally, predict is used to forecast the outcome for a new, unseen customer.
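For a churn model you'll often want a probability and the rules behind it, not just a 0 or 1. Here is a minimal sketch using predict_proba and export_text on the same sample data; the features list and the fixed random_state are additions for illustration.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Sample customer data (1=churned, 0=retained)
customers = pd.DataFrame({
    'tenure_months': [2, 7, 5, 1, 8, 2, 4],
    'monthly_charge': [65, 45, 85, 95, 35, 75, 50],
    'churn': [1, 0, 1, 1, 0, 1, 0]
})
features = ['tenure_months', 'monthly_charge']
model = DecisionTreeClassifier(max_depth=2, random_state=42)
model.fit(customers[features], customers['churn'])

# predict_proba returns one column per class, in the order of model.classes_
new_customer = pd.DataFrame({'tenure_months': [3], 'monthly_charge': [60]})
probabilities = model.predict_proba(new_customer)[0]
print(f"P(retained)={probabilities[0]:.2f}, P(churned)={probabilities[1]:.2f}")

# Print the learned rules to check that they make business sense
print(export_text(model, feature_names=features))
```

With only seven rows the leaves end up pure, so the probabilities come out as 0 or 1. On real customer data they spread out, and you can set a threshold that fits your retention budget.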

Predicting loan defaults with feature_importances_

In finance, you can use feature_importances_ to identify which factors in a loan application are most predictive of a default.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Generate synthetic loan data (age, income, credit_score, default)
np.random.seed(42)
X = np.column_stack([
    np.random.normal(35, 10, 1000),        # age
    np.random.normal(60000, 20000, 1000),  # income
    np.random.normal(700, 100, 1000)       # credit score
])
y = ((X[:, 0] < 25) | (X[:, 1] < 40000) | (X[:, 2] < 600)).astype(int)

model = DecisionTreeClassifier(max_depth=3).fit(X, y)
for name, importance in zip(['Age', 'Income', 'Credit Score'], model.feature_importances_):
    print(f"{name}: {importance:.2f}")

This code generates synthetic loan data using NumPy. It draws age, income, and credit score from normal distributions with np.random.normal. The target variable, default, follows a clear rule built with the | (or) operator: a loan defaults when the applicant is under 25, earns less than 40,000, or has a credit score below 600.

  • A DecisionTreeClassifier is trained on this data, with max_depth=3 to control its complexity.
  • The model learns the relationships between the features and the default outcome you defined.
  • Finally, the code inspects the trained model to see how it weighed each feature when making decisions.
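Because the default rule is known here, you can check whether the tree recovered it. The sketch below holds out a test set and prints the learned thresholds with export_text. It uses NumPy's default_rng instead of np.random.seed, so the generated numbers differ from the example above; the split and feature_names list are also additions for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Same distributions and rule as above, drawn from a seeded Generator
rng = np.random.default_rng(42)
X = np.column_stack([
    rng.normal(35, 10, 1000),        # age
    rng.normal(60000, 20000, 1000),  # income
    rng.normal(700, 100, 1000)       # credit score
])
y = ((X[:, 0] < 25) | (X[:, 1] < 40000) | (X[:, 2] < 600)).astype(int)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

# The printed thresholds can be compared with the rule's 25, 40000, and 600
feature_names = ['Age', 'Income', 'Credit Score']
print(export_text(model, feature_names=feature_names))

test_accuracy = model.score(X_test, y_test)
print(f"Held-out accuracy: {test_accuracy:.2f}")
```

Real loan data won't follow rules this clean, so expect lower accuracy and noisier thresholds there.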

Get started with Replit

Now, turn what you've learned into a real tool. Give Replit Agent a prompt like, “build a churn predictor that shows feature_importances_,” or “create a loan risk calculator that uses SMOTE to balance the data.”

It will write the code, test for errors, and deploy your application directly from your browser. Start building with Replit.

Get started free

Create and deploy websites, automations, internal tools, data pipelines and more in any programming language without setup, downloads or extra tools. All in a single cloud workspace with AI built in.
