import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
I've found that some processes can get pretty complicated to explain. During those moments of quiet exasperation, there's only one thing to do... MAKE A SANKEY PLOT!
Seriously, though, a Sankey plot can often show something that would be much harder to explain in words. For example, let's say I want to take a group of elementary school students on a field trip, and I have them vote on whether they want to go to a museum or an aquarium.
We can illustrate this with a Sankey diagram, which shows how quantities flow from one categorical variable to the next. Here's a simple illustration of this example that I made using Draw.io.
As you can see, we start with 175 students on the left. We split those 175 students by grade, then by whether they are girls or boys, and finally by which location they voted for.
The plot I made by hand is fine, but I'd really prefer to make it with code. To do this, we're going to use the plotly package in Python.
The Plotly documentation for Sankey diagrams shows that we need four pieces of information to make one: node labels, source values, target values, and the numeric value of each flow.
This might seem complicated, but don't fret! First, we need our node labels, which are just the text in the boxes of the diagram above.
We'll read the information from the chart we made starting from the top and moving down, and from the left moving right, like we're reading columns in a newspaper.
node_labels = ['Students', '4th Grade', '5th Grade', 'Girls', 'Boys', 'Aquarium', 'Museum']
We also need a matching unique numeric identifier for each node label, because plotly refers to nodes by their position in the label list rather than by name. To do this, we just enumerate the entries in our node_labels variable.
node_ids = {y:x for x, y in enumerate(node_labels)}
node_ids
{'Students': 0, '4th Grade': 1, '5th Grade': 2, 'Girls': 3, 'Boys': 4, 'Aquarium': 5, 'Museum': 6}
Now we need our source values, target values, and numeric values.
For readability, we'll write our source and target values using the exact same strings as our node labels, and convert them to numeric IDs at the end.
If you examine the arrows in the diagram above, the tail (left end) of each arrow marks a source. For example, two arrows leave the "Students" box, so "Students" has to appear twice in the source list. We work through the rest of the diagram the same way.
source = ['Students', 'Students', '4th Grade', '4th Grade', '5th Grade', '5th Grade',
'Girls', 'Girls', 'Boys', 'Boys']
Notice that the source list does not include "Museum" or "Aquarium", because neither of those boxes is the origin of any arrow.
Now we need to enter the target values. This is similar to the last task, except this time we're recording where each arrow ends. The order matters: the nth target must be the endpoint of the same arrow whose start is the nth source.
The "Students" box has no arrows ending on it, so it never appears as a target. The first source entry is the arrow from "Students" to "4th Grade", so "4th Grade" comes first. "Girls" and "Boys" each appear twice, once for the arrow from "4th Grade" and once for the arrow from "5th Grade".
target = ['4th Grade', '5th Grade', 'Girls', 'Boys', 'Girls', 'Boys',
'Aquarium', 'Museum', 'Aquarium', 'Museum']
Now we just need the numeric value of each flow, as labeled in the diagram, again listed in the same order as the sources and targets.
If you get confused, look at the entries in the same position in the source and target lists. The first entry in the source list is "Students" and the first entry in the target list is "4th Grade", so the first entry in the value list is the number of students in the 4th grade, 75.
value = [75, 100, 40, 35, 45, 55, 65, 20, 30, 60]
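Since every student who enters a middle box also has to leave it, a quick sanity check is to compare each node's total inflow with its total outflow. This check is my own addition rather than part of the plotly workflow; it uses only the standard library and repeats the three lists so it runs on its own.

```python
from collections import Counter

source = ['Students', 'Students', '4th Grade', '4th Grade', '5th Grade', '5th Grade',
          'Girls', 'Girls', 'Boys', 'Boys']
target = ['4th Grade', '5th Grade', 'Girls', 'Boys', 'Girls', 'Boys',
          'Aquarium', 'Museum', 'Aquarium', 'Museum']
value = [75, 100, 40, 35, 45, 55, 65, 20, 30, 60]

# The three lists must line up entry for entry.
assert len(source) == len(target) == len(value)

# Sum the values leaving and entering each node.
outflow, inflow = Counter(), Counter()
for s, t, v in zip(source, target, value):
    outflow[s] += v
    inflow[t] += v

# Nodes that are both a source and a target should balance.
unbalanced = {n: (inflow[n], outflow[n])
              for n in set(source) & set(target)
              if inflow[n] != outflow[n]}
print(unbalanced or 'all intermediate nodes balance')  # all intermediate nodes balance
```

In this example every arrow touches at least one middle box, so a single mistyped value would show up as an unbalanced node.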
Finally, we convert the source and target strings into their numeric IDs using the node_ids dictionary.
source_node = [node_ids[x] for x in source]
target_node = [node_ids[x] for x in target]
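Keeping three parallel lists in sync by hand is where most mistakes creep in. One alternative, which is my own suggestion rather than anything plotly requires, is to write each arrow once as a (source, target, value) tuple and split the tuples apart with zip:

```python
node_labels = ['Students', '4th Grade', '5th Grade', 'Girls', 'Boys', 'Aquarium', 'Museum']
node_ids = {label: i for i, label in enumerate(node_labels)}

# One tuple per arrow in the diagram: (source, target, number of students).
flows = [
    ('Students', '4th Grade', 75), ('Students', '5th Grade', 100),
    ('4th Grade', 'Girls', 40), ('4th Grade', 'Boys', 35),
    ('5th Grade', 'Girls', 45), ('5th Grade', 'Boys', 55),
    ('Girls', 'Aquarium', 65), ('Girls', 'Museum', 20),
    ('Boys', 'Aquarium', 30), ('Boys', 'Museum', 60),
]

# zip(*flows) transposes the list of tuples into three parallel sequences.
source, target, value = (list(col) for col in zip(*flows))
source_node = [node_ids[s] for s in source]
target_node = [node_ids[t] for t in target]

print(source_node)  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
print(target_node)  # [1, 2, 3, 4, 3, 4, 5, 6, 5, 6]
```

Because each arrow lives on one line, a source can never drift out of step with its target or value, and the resulting source_node, target_node, and value lists drop straight into the go.Sankey call.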
Now that we have all of the components to our Sankey diagram, we're ready to plot!
fig = go.Figure(
    data=[go.Sankey(
        node=dict(label=node_labels),
        link=dict(
            source=source_node,
            target=target_node,
            value=value,
        ),
    )]
)
fig.show()
Because this is a static HTML page, the diagram above is just a PNG image. In a Jupyter notebook, however, the diagram is interactive and responds to your mouse: you can drag the nodes (the bars) around until the layout looks right, then click the camera button in the upper right corner of the plot to download it as a PNG image.
Using this process, you can make some pretty complex Sankey plots! You can also include sample sizes in your node labels so that the exported image shows the count at each stage. For example, here is the code for a Sankey plot I created while sorting images for a computer vision project.
node_labels = ['NRA - n = 845,015', 'Everytown - n = 259,098', 'FAIR - n = 415,262', 'UWD - n = 129,985',
'NRA - n = 727,032', 'Everytown - n = 241,981','FAIR - n = 331,946', 'UWD - n = 125,201', 'Removed - n = 223,200',
'NRA - n = 263,021', 'Everytown - n = 108,975', 'FAIR - n = 82,955', 'UWD - n = 56,218', 'Removed - n = 914,991',
'NRA - n = 212,558', 'Everytown - n = 95,397', 'FAIR - n = 69,818', 'UWD - n = 49,583', 'Removed - n = 83,813',
'NRA - n = 171,618', 'Everytown - n = 82,218', 'FAIR - n = 57,687', 'UWD - n = 43,859', 'Removed - n = 71,974',
'Final - n = 313,302', 'Removed - n = 42,080']
node_ids = {y:x for x, y in enumerate(node_labels)}
node_ids
{'NRA - n = 845,015': 0, 'Everytown - n = 259,098': 1, 'FAIR - n = 415,262': 2, 'UWD - n = 129,985': 3, 'NRA - n = 727,032': 4, 'Everytown - n = 241,981': 5, 'FAIR - n = 331,946': 6, 'UWD - n = 125,201': 7, 'Removed - n = 223,200': 8, 'NRA - n = 263,021': 9, 'Everytown - n = 108,975': 10, 'FAIR - n = 82,955': 11, 'UWD - n = 56,218': 12, 'Removed - n = 914,991': 13, 'NRA - n = 212,558': 14, 'Everytown - n = 95,397': 15, 'FAIR - n = 69,818': 16, 'UWD - n = 49,583': 17, 'Removed - n = 83,813': 18, 'NRA - n = 171,618': 19, 'Everytown - n = 82,218': 20, 'FAIR - n = 57,687': 21, 'UWD - n = 43,859': 22, 'Removed - n = 71,974': 23, 'Final - n = 313,302': 24, 'Removed - n = 42,080': 25}
source = ['NRA - n = 845,015','Everytown - n = 259,098','FAIR - n = 415,262','UWD - n = 129,985',
'NRA - n = 845,015','Everytown - n = 259,098','FAIR - n = 415,262','UWD - n = 129,985',
'NRA - n = 727,032','Everytown - n = 241,981','FAIR - n = 331,946','UWD - n = 125,201',
'NRA - n = 727,032','Everytown - n = 241,981','FAIR - n = 331,946','UWD - n = 125,201',
'NRA - n = 263,021','Everytown - n = 108,975','FAIR - n = 82,955','UWD - n = 56,218',
'NRA - n = 263,021','Everytown - n = 108,975','FAIR - n = 82,955','UWD - n = 56,218',
'NRA - n = 212,558','Everytown - n = 95,397','FAIR - n = 69,818','UWD - n = 49,583',
'NRA - n = 212,558','Everytown - n = 95,397','FAIR - n = 69,818','UWD - n = 49,583',
'NRA - n = 171,618','Everytown - n = 82,218','FAIR - n = 57,687','UWD - n = 43,859',
'NRA - n = 171,618','Everytown - n = 82,218','FAIR - n = 57,687','UWD - n = 43,859']
target = ['NRA - n = 727,032','Everytown - n = 241,981','FAIR - n = 331,946','UWD - n = 125,201',
'Removed - n = 223,200','Removed - n = 223,200','Removed - n = 223,200','Removed - n = 223,200',
'NRA - n = 263,021','Everytown - n = 108,975','FAIR - n = 82,955','UWD - n = 56,218',
'Removed - n = 914,991','Removed - n = 914,991','Removed - n = 914,991','Removed - n = 914,991',
'NRA - n = 212,558','Everytown - n = 95,397','FAIR - n = 69,818','UWD - n = 49,583',
'Removed - n = 83,813','Removed - n = 83,813','Removed - n = 83,813','Removed - n = 83,813',
'NRA - n = 171,618','Everytown - n = 82,218','FAIR - n = 57,687','UWD - n = 43,859',
'Removed - n = 71,974','Removed - n = 71,974','Removed - n = 71,974','Removed - n = 71,974',
'Final - n = 313,302','Final - n = 313,302','Final - n = 313,302','Final - n = 313,302',
'Removed - n = 42,080','Removed - n = 42,080','Removed - n = 42,080','Removed - n = 42,080']
value = [ 727032, 241981, 331946, 125201, 117983, 17117, 83316, 4784,
263021, 108975, 82955, 56218, 464011, 133006, 248991, 68983,
212558,95397,69818,49583,50463,13578,13137,6635,
171618,82218,57687,43859,40940,13179,12131,5724,
158537,73480,44787,36498,13081,8738,12900,7361]
source_node = [node_ids[x] for x in source]
target_node = [node_ids[x] for x in target]
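Writing out 40 source and target strings by hand is tedious, and every label has to be typed identically each time it appears. Because this diagram has a regular shape (four groups, each step either keeping or removing images, then a final pooling step), the same lists can also be generated from the per-stage counts. This is a sketch of one way to do that, not how the original plot was built: the kept counts are copied from the node labels and values above, and each "Removed" flow is computed as the difference between consecutive stages. The helper names fmt and add_node are hypothetical.

```python
# Group names, in the order they appear in each column of the diagram.
groups = ['NRA', 'Everytown', 'FAIR', 'UWD']

# Images kept per group after each step, copied from the example above.
# The last row is what each group contributes to the final pooled set.
kept = [
    [845015, 259098, 415262, 129985],
    [727032, 241981, 331946, 125201],
    [263021, 108975, 82955, 56218],
    [212558, 95397, 69818, 49583],
    [171618, 82218, 57687, 43859],
    [158537, 73480, 44787, 36498],
]

node_labels, source_node, target_node, value = [], [], [], []


def fmt(name, n):
    """Format a label like 'NRA - n = 845,015' (':,' adds thousands separators)."""
    return f'{name} - n = {n:,}'


def add_node(label):
    """Append a node label and return its integer ID."""
    node_labels.append(label)
    return len(node_labels) - 1


# First column: one node per group.
prev_ids = [add_node(fmt(g, n)) for g, n in zip(groups, kept[0])]

for step in range(1, len(kept)):
    before, after = kept[step - 1], kept[step]
    removed = [b - a for b, a in zip(before, after)]
    if step == len(kept) - 1:
        # Last step: every group flows into a single 'Final' node.
        final_id = add_node(fmt('Final', sum(after)))
        next_ids = [final_id] * len(groups)
    else:
        next_ids = [add_node(fmt(g, n)) for g, n in zip(groups, after)]
    removed_id = add_node(fmt('Removed', sum(removed)))

    # Kept links first, then removed links, matching the hand-written order.
    for i in range(len(groups)):
        source_node.append(prev_ids[i])
        target_node.append(next_ids[i])
        value.append(after[i])
    for i in range(len(groups)):
        source_node.append(prev_ids[i])
        target_node.append(removed_id)
        value.append(removed[i])
    prev_ids = next_ids

print(len(node_labels), len(value))  # 26 40
```

The node order and link order come out the same as the hand-typed version, so the generated node_labels, source_node, target_node, and value lists can be passed to go.Sankey unchanged.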
fig = go.Figure(
    data=[go.Sankey(
        node=dict(label=node_labels),
        link=dict(
            source=source_node,
            target=target_node,
            value=value,
        ),
    )]
)
fig.show()
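Obviously, this is a little messy. Opening the diagram in a separate browser window makes it easier to arrange the nodes before saving. The plot function from plotly.offline writes the figure to a standalone HTML file (temp-plot.html by default) and opens it in your browser; passing image='png' also downloads a PNG named by image_filename, at the size given by image_width and image_height.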
plot(fig, image_filename='sankey_plot_1', image='png', image_width=1000, image_height=600);
After a little love, this is the resulting image.