How To Create Subplots in Python Using Matplotlib

Hey - Nick here! This page is a free excerpt from my $199 course Python for Finance, which is 50% off for the next 50 students.

In this lesson, you will learn how to create subplots in Python using matplotlib.

The Imports You’ll Need For This Lesson

As before, you will need the following imports for this lesson:


import pandas as pd

import matplotlib.pyplot as plt

%matplotlib inline

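#Newer IPython versions deprecate IPython.display.set_matplotlib_formats in favor of this import

from matplotlib_inline.backend_inline import set_matplotlib_formats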

set_matplotlib_formats('retina')

We will also be importing several data sets for the examples near the end of this lesson. Those imports will be covered later in this tutorial.

What is a Subplot?

There are many cases where you will want to generate a plot that contains several smaller plots within it. That is exactly what a subplot is!

A common version of the subplot is the 2x2 subplot, which arranges four plots in two rows and two columns. An example of the 2x2 subplot is below:

An Example of a Subplot in Matplotlib

Creating a polished subplot can take a fair amount of code. As an example, consider the code that I used to create the above 2x2 subplot:


import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

%matplotlib inline

from datetime import datetime

tech_stocks_data = pd.read_csv('https://raw.githubusercontent.com/nicholasmccullum/python-visualization/master/tech_stocks/GOOG_MSFT_FB_AMZN_data.csv')

tech_stocks_data.sort_values('Period', ascending = True, inplace = True)

google = tech_stocks_data['Alphabet Inc Price']

amazon = tech_stocks_data['Amazon.com Inc Price']

facebook = tech_stocks_data['Facebook Inc Price']

microsoft = tech_stocks_data['Microsoft Corp Price']

dates = tech_stocks_data['Period']

x = []

for date in tech_stocks_data['Period']:

    x.append(datetime.strptime(date, '%Y-%m-%d %H:%M:%S').year)

    

plt.figure(figsize=(16,12))

#Plot 1

plt.subplot(2,2,1)

plt.xticks(np.arange(len(x))[::365], x[::365])

plt.plot(dates, google)

plt.title('Alphabet (GOOG) (GOOGL) Stock Price')

#Plot 2

plt.subplot(2,2,2)

plt.xticks(np.arange(len(x))[::365], x[::365])

plt.plot(dates, amazon)

plt.title('Amazon (AMZN) Stock Price')

#Plot 3

plt.subplot(2,2,3)

plt.xticks(np.arange(len(x))[::365], x[::365])

plt.plot(dates, facebook)

plt.title('Facebook (FB) Stock Price')

#Plot 4

plt.subplot(2,2,4)

plt.xticks(np.arange(len(x))[::365], x[::365])

plt.plot(dates, microsoft)

plt.title('Microsoft (MSFT) Stock Price')

This might seem overwhelming. We will work through the process of creating subplots step-by-step through the remainder of this lesson.

How To Create Subplots in Python Using Matplotlib

We can create subplots in Python using matplotlib's plt.subplot function, which takes three arguments:

  • nrows: The number of rows of subplots in the plot grid.
  • ncols: The number of columns of subplots in the plot grid.
  • index: The position in the plot grid of the subplot that you want to select.

The nrows and ncols arguments are relatively straightforward, but the index argument may require some explanation. It starts at 1 and moves through each row of the plot grid one-by-one. When it reaches the end of a row, it will move down to the first entry of the next row.

A few examples of selecting specific subplots within a plot grid are shown below:


plt.subplot(3,3,5)

#Selects the middle entry of the second row in the 3x3 subplot grid

plt.subplot(1,2,2)

#Selects the second entry in a 1x2 subplot grid

plt.subplot(4,4,16)

#Selects the last entry in a 4x4 subplot grid
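
To make the index rule concrete, here is a small sketch that converts a subplot index into its zero-based row and column. The helper name index_to_row_col is my own invention for illustration and is not part of matplotlib; it only mirrors the counting rule described above.

```python
def index_to_row_col(nrows, ncols, index):
    """Return the zero-based (row, column) for a 1-based subplot index."""
    if not 1 <= index <= nrows * ncols:
        raise ValueError('index must be between 1 and nrows * ncols')
    # Indices run left to right along a row, then wrap to the next row
    return divmod(index - 1, ncols)

print(index_to_row_col(3, 3, 5))   # (1, 1): middle entry of the second row
print(index_to_row_col(1, 2, 2))   # (0, 1): second entry of the only row
print(index_to_row_col(4, 4, 16))  # (3, 3): last entry of the last row
```

The three calls match the three plt.subplot examples above, so you can use them to check your own mental counting.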

We will work through two examples of how to create subplot grids before concluding this lesson.

Example #1: A 2x2 Subplot Grid

First, let’s import the Iris data set:


iris_data = pd.read_json('https://raw.githubusercontent.com/nicholasmccullum/python-visualization/master/iris/iris.json')

Using the Iris data set, let’s create a 2x2 subplot with a subplot for each of the following variables (in the order they’re listed):

  • sepalLength
  • sepalWidth
  • petalLength
  • petalWidth

Make each subplot a histogram with 15 bins. Make sure to give each subplot a reasonable title so that an outside reader could understand the data.

Give this a try yourself before proceeding!

Once you have attempted this on your own, you can view the code below for a full solution:


plt.subplot(2,2,1)

plt.hist(iris_data['sepalLength'], bins = 15)

plt.title('A Histogram of Sepal Lengths from the Iris Data Set')

plt.subplot(2,2,2)

plt.hist(iris_data['sepalWidth'], bins = 15)

plt.title('A Histogram of Sepal Widths from the Iris Data Set')

plt.subplot(2,2,3)

plt.hist(iris_data['petalLength'], bins = 15)

plt.title('A Histogram of Petal Lengths from the Iris Data Set')

plt.subplot(2,2,4)

plt.hist(iris_data['petalWidth'], bins = 15)

plt.title('A Histogram of Petal Widths from the Iris Data Set')

Your First Subplot!
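
Since the four blocks in the solution differ only in the column name and the title, one possible way to shorten the code is a loop. The sketch below is an alternative, not the course solution. To stay self-contained it uses a stand-in DataFrame named demo_data filled with random numbers (an assumption for illustration, not the real Iris measurements), and it selects the Agg backend so that it also runs outside of a Jupyter Notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Assumption: non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Stand-in data with the Iris column names (random values, not real measurements)
rng = np.random.default_rng(0)
columns = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth']
labels = ['Sepal Lengths', 'Sepal Widths', 'Petal Lengths', 'Petal Widths']
demo_data = pd.DataFrame(rng.normal(size=(150, 4)), columns=columns)

fig = plt.figure(figsize=(12, 8))

# enumerate(..., start=1) yields the 1-based index that plt.subplot expects
for index, (column, label) in enumerate(zip(columns, labels), start=1):
    plt.subplot(2, 2, index)
    plt.hist(demo_data[column], bins=15)
    plt.title('A Histogram of ' + label + ' from the Iris Data Set')
```

To use this with the real data, replace demo_data with the iris_data DataFrame imported earlier.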

Example #2: A 2x3 Subplot Grid

For this example, let’s again import the wine quality data set from the UCI Machine Learning Repository. The import is below:


wine_data = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv', sep=';')

Let’s create a 2x3 subplot with the following plots (in the order they’re listed):

  • chlorides
  • quality
  • alcohol
  • density
  • total sulfur dioxide
  • citric acid

Let’s make each subplot a scatterplot, with the x-variable for each scatterplot being fixed acidity. Name each plot with an appropriate title for an outside reader to understand it.

Give this a try yourself before proceeding!

Once you have attempted this on your own, you can view the code below for a full solution:


x = wine_data['fixed acidity']

plt.subplot(2,3,1)

plt.scatter(x, wine_data['chlorides'])

plt.title('Chlorides plotted against Fixed Acidity')

plt.subplot(2,3,2)

plt.scatter(x, wine_data['quality'])

plt.title('Quality plotted against Fixed Acidity')

plt.subplot(2,3,3)

plt.scatter(x, wine_data['alcohol'])

plt.title('Alcohol plotted against Fixed Acidity')

plt.subplot(2,3,4)

plt.scatter(x, wine_data['density'])

plt.title('Density plotted against Fixed Acidity')

plt.subplot(2,3,5)

plt.scatter(x, wine_data['total sulfur dioxide'])

plt.title('Total Sulfur Dioxide plotted against Fixed Acidity')

plt.subplot(2,3,6)

plt.scatter(x, wine_data['citric acid'])

plt.title('Citric Acid plotted against Fixed Acidity')

Your Second Subplot!
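
Matplotlib also offers plt.subplots (note the trailing s), which creates the whole grid in one call and returns the figure along with an array of axes. The sketch below shows how the same 2x3 layout might look with that function. It is an alternative to the solution above, not a replacement for it. The demo_wine DataFrame holds random stand-in values (an assumption for illustration, not the real wine data), and the Agg backend is selected so the code also runs outside of a Jupyter Notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Assumption: non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Stand-in data with the wine column names (random values, not real measurements)
rng = np.random.default_rng(1)
y_columns = ['chlorides', 'quality', 'alcohol', 'density',
             'total sulfur dioxide', 'citric acid']
demo_wine = pd.DataFrame(rng.random((100, 7)),
                         columns=['fixed acidity'] + y_columns)

# axes is a 2x3 NumPy array holding one Axes object per grid cell
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

# axes.flat walks the grid row by row, the same order as the subplot index
for ax, column in zip(axes.flat, y_columns):
    ax.scatter(demo_wine['fixed acidity'], demo_wine[column])
    ax.set_title(column.title() + ' plotted against Fixed Acidity')
```

With this approach you call methods on each Axes object (ax.scatter, ax.set_title) instead of selecting a subplot first and then calling the plt functions.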

The plt.tight_layout Method

While creating Python visualizations, you will often encounter situations where your subplots have axis labels that overlap one another.

As an example, let’s run the following code to create 25 empty matplotlib plots:


#Import the necessary Python libraries

import matplotlib.pyplot as plt

import numpy as np

#Set matplotlib to display plots inline in the Jupyter Notebook

%matplotlib inline

#Resize the matplotlib canvas

plt.figure(figsize=(16,12))

#Create 25 empty plots

for x in (np.arange(25)+1):

    plt.subplot(5,5,x)

    plt.plot()

Here is the output of this code:

Empty Subplots in Matplotlib

As you can see, the axis tick labels of neighboring subplots overlap one another, which makes the figure hard to read.

If you add the plt.tight_layout() statement to the end of this code block, matplotlib adjusts the spacing around the subplots so that the labels no longer collide. Here is the same output with the added statement:

Empty Subplots in Matplotlib
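
For reference, here is a sketch of the same 5x5 grid with the plt.tight_layout() call in place. The left_before and left_after variables are my own additions for illustration: they record the left edge of the first subplot so that you can see that the layout actually changes. The Agg backend is an assumption that lets the code run outside of a Jupyter Notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Assumption: non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(16, 12))

# Create 25 empty plots in a 5x5 grid
for index in range(1, 26):
    plt.subplot(5, 5, index)
    plt.plot()

# Left edge of the first subplot, as a fraction of the figure width
left_before = fig.axes[0].get_position().x0

# Recompute the margins and spacing so that labels fit without overlapping
plt.tight_layout()

left_after = fig.axes[0].get_position().x0
print(left_before, left_after)
```

Note that plt.tight_layout() should be called after all of the subplots have been created, since it only adjusts the subplots that exist at the time of the call.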

Moving On

In this lesson, we learned how to create subplot grids in Python using matplotlib. Subplots are an important part of your Python visualization toolkit.