#Importing pandas to handle dataframe
import pandas as pd
# Suppress pandas warnings
import warnings
warnings.filterwarnings("ignore")
(A) Introduction
Data visualization plays a vital role in domains such as data analytics, data science, dashboarding, and exploratory/statistical analysis. Several visualization libraries are commonly used within the Python and R ecosystems.
Among these, one of the most widely used is the Plotly Graphing Library. It is available in multiple languages, produces high-quality scientific and non-scientific graphs, and makes interactive plots easy to share.
In this post, I discuss an intriguing plot type called the Sunburst Chart. Sunburst charts display layered information interactively, which makes complex hierarchical data easier to understand.
(B) Sunburst Chart
A sunburst chart is a powerful tool for visualizing hierarchical datasets. In a hierarchical dataset, the features or variables have parent-child relationships that form a tree-like structure. To generate a sunburst plot in Plotly, you can use either the plotly.express module or the plotly.graph_objects module.
Let’s consider an example dataframe (dummy data for demonstration purposes) with a tree-like structure, where the columns or features exhibit parent-child relationships with other columns.
- General Dataset: a dataframe with classes and values organized in columns, as in the sample data shown in section (C).
- Sunburst DataFrame: a hierarchical dataframe that defines the logical parent-child relationships between the columns and their corresponding values.
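To make the two forms concrete, here is a minimal sketch that writes both by hand for three rows taken from the dummy data. The flat frame has one row per city. The hierarchical frame lists every node (city, state, country, and an assumed "Total" root) once, together with its parent and its aggregated value. The variable names are illustrative only.

```python
import pandas as pd

# General (flat) form: one row per city, hierarchy spread across columns.
flat = pd.DataFrame({
    "Country": ["India", "India", "USA"],
    "State": ["INCG", "INCG", "USNY"],
    "City": ["B2", "M1", "C2"],
    "Population": [12201, 9021, 812],
})

# Sunburst (hierarchical) form: every node once, with its parent and its
# aggregated value. The root "Total" has an empty parent.
hierarchical = pd.DataFrame({
    "id": ["B2", "M1", "C2", "INCG", "USNY", "India", "USA", "Total"],
    "parent": ["INCG", "INCG", "USNY", "India", "USA", "Total", "Total", ""],
    "value": [12201, 9021, 812, 21222, 812, 21222, 812, 22034],
})
print(hierarchical)
```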
Now, let’s delve into how this data would appear by visualizing it using a sunburst chart.
(C) Datasets
The following dataset is dummy data for demonstration. You are likely to come across this kind of data while working on data science or analytics projects.
data = pd.read_csv("../data/dummy_data.csv")
data.head()
| | Country | State | City | Population |
---|---|---|---|---|
0 | India | INMP | A1 | 512 |
1 | India | INCG | B2 | 12201 |
2 | India | INCG | M1 | 9021 |
3 | USA | USNY | C2 | 812 |
4 | USA | USNY | N1 | 821 |
The dataset is not in hierarchical form. A sunburst trace needs a parent, a child (id), and a value for every sector. Plotly Express can derive these from the flat columns through its `path` argument, but for graph objects we need to convert the table into a "chart-acceptable" format ourselves. The following function does this. It is a modified version of the function defined on the sunburst chart page of Plotly's documentation, where you can find more details.
def build_hierarchical_dataframe(df, levels, value_column, metric):
    """
    Build a hierarchy of levels for Sunburst.
    - Levels are given starting from the bottom to the top of the hierarchy,
      i.e. the last level corresponds to the root.
    - Input :
        - df : pandas dataframe
        - levels : list of column names in the order, child to root.
        - value_column : name of the column whose values are displayed in the chart.
        - metric : string value equal to "sum" or "count".
    - Output:
        - df_all_trees : pandas dataframe for sunburst with columns, ['id', 'parent', 'value'].
    """
    df_all_trees = pd.DataFrame(columns=['id', 'parent', 'value'])
    for i, level in enumerate(levels):
        df_tree = pd.DataFrame(columns=['id', 'parent', 'value'])
        ## Groupby based upon metric chosen; aggregate only the value column
        ## so that text columns (e.g. City) are not summed as strings
        if metric == "count":
            dfg = df.groupby(levels[i:])[value_column].count()
        else:
            dfg = df.groupby(levels[i:])[value_column].sum()

        dfg = dfg.reset_index()
        df_tree['id'] = dfg[level].copy()

        ## Set parent of the levels (the top level hangs under 'Total')
        if i < len(levels) - 1:
            df_tree['parent'] = dfg[levels[i+1]].copy()
        else:
            df_tree['parent'] = 'Total'

        df_tree['value'] = dfg[value_column]
        df_all_trees = pd.concat([df_all_trees, df_tree], ignore_index=True)

    ## Value calculation for parent (root node)
    if metric == "count":
        total = pd.Series(dict(id='Total', parent='', value=df[value_column].count()))
    else:
        total = pd.Series(dict(id='Total', parent='', value=df[value_column].sum()))

    ## Add frames one below the other to form the final dataframe
    df_all_trees = pd.concat([df_all_trees, pd.DataFrame([total])], ignore_index=True)
    return df_all_trees
levels = ['City', 'State', 'Country']
value_column = 'Population'
metric = "sum"
Hierarchical Sum dataframe
This dataframe represents the total population across each Country, State and City under study.
df_sum = build_hierarchical_dataframe(data, levels, value_column, metric="sum")
df_sum
| | id | parent | value |
---|---|---|---|
0 | A1 | INMP | 512 |
1 | B2 | INCG | 12201 |
2 | C2 | USNY | 812 |
3 | D1 | INSD | 9104 |
4 | E2 | INGD | 132 |
5 | F1 | USSF | 82 |
6 | G2 | INSA | 5121 |
7 | H1 | INAS | 1232 |
8 | I2 | USHF | 8841 |
9 | J1 | INSR | 11 |
10 | K2 | INCQ | 1236 |
11 | L3 | USSF | 1200 |
12 | M1 | INCG | 9021 |
13 | N1 | USNY | 821 |
14 | O2 | USNY | 128 |
15 | P1 | INSD | 20 |
16 | Q1 | USXO | 4120 |
17 | R1 | USXO | 60 |
18 | S1 | INGD | 6012 |
19 | INAS | India | 1232 |
20 | INCG | India | 21222 |
21 | INCQ | India | 1236 |
22 | INGD | India | 6144 |
23 | INMP | India | 512 |
24 | INSA | India | 5121 |
25 | INSD | India | 9124 |
26 | INSR | India | 11 |
27 | USHF | USA | 8841 |
28 | USNY | USA | 1761 |
29 | USSF | USA | 1282 |
30 | USXO | USA | 4180 |
31 | India | Total | 44602 |
32 | USA | Total | 16064 |
33 | Total | | 60666 |
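The graph-objects plots below use `branchvalues='total'`, which treats each parent's value as including its children. If a parent's value is smaller than the sum of its children, the chart may not render as expected. A quick check on a hierarchical frame can catch this before plotting. `check_branch_totals` is a hypothetical helper. It is stricter than Plotly because it flags any difference, not only a parent that is too small. The sample uses the USNY and USSF rows from the table above, with USA and Total recomputed for just these two states.

```python
import pandas as pd

def check_branch_totals(tree: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: return nodes whose value differs from their children's sum."""
    # Sum the children of every parent; "" is only the (empty) parent of the root.
    child_sums = tree.groupby("parent")["value"].sum().drop(labels="", errors="ignore")
    # Look up each parent's own value, aligned on the same index.
    own = tree.set_index("id")["value"].reindex(child_sums.index)
    mismatch = own != child_sums
    return pd.DataFrame({"value": own[mismatch], "children_sum": child_sums[mismatch]})

tree = pd.DataFrame({
    "id": ["C2", "N1", "O2", "F1", "L3", "USNY", "USSF", "USA", "Total"],
    "parent": ["USNY", "USNY", "USNY", "USSF", "USSF", "USA", "USA", "Total", ""],
    "value": [812, 821, 128, 82, 1200, 1761, 1282, 3043, 3043],
})
print(check_branch_totals(tree).empty)  # True
```

Running the same check on `df_sum` or `df_count` should return an empty frame, because each parent value above is the sum of its children.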
Hierarchical Count dataframe
This dataframe represents the number of sub-classes (such as cities) across each Country and State under study.
df_count = build_hierarchical_dataframe(data, levels, value_column, metric="count")
df_count
| | id | parent | value |
---|---|---|---|
0 | A1 | INMP | 1 |
1 | B2 | INCG | 1 |
2 | C2 | USNY | 1 |
3 | D1 | INSD | 1 |
4 | E2 | INGD | 1 |
5 | F1 | USSF | 1 |
6 | G2 | INSA | 1 |
7 | H1 | INAS | 1 |
8 | I2 | USHF | 1 |
9 | J1 | INSR | 1 |
10 | K2 | INCQ | 1 |
11 | L3 | USSF | 1 |
12 | M1 | INCG | 1 |
13 | N1 | USNY | 1 |
14 | O2 | USNY | 1 |
15 | P1 | INSD | 1 |
16 | Q1 | USXO | 1 |
17 | R1 | USXO | 1 |
18 | S1 | INGD | 1 |
19 | INAS | India | 1 |
20 | INCG | India | 2 |
21 | INCQ | India | 1 |
22 | INGD | India | 2 |
23 | INMP | India | 1 |
24 | INSA | India | 1 |
25 | INSD | India | 2 |
26 | INSR | India | 1 |
27 | USHF | USA | 1 |
28 | USNY | USA | 3 |
29 | USSF | USA | 2 |
30 | USXO | USA | 2 |
31 | India | Total | 11 |
32 | USA | Total | 8 |
33 | Total | | 19 |
(D) Visualizations
Next, we look at the two most common ways of plotting sunburst charts in Python. You can use either of the following modules:
- Plotly Express
- Plotly Graph Objects
Both modules produce the same kind of figure object (a plotly.graph_objects Figure). They differ in code syntax and in how flexibly the graph can be modified. Plotly Express builds a complete plot from a single function call with a predefined set of parameters. Graph objects let you tweak every detail, which some users find more comfortable. The beauty of Plotly, however, is that a figure generated with Plotly Express can be modified in all the same ways as one built with graph objects.
We will use both of them to generate plots for the datasets created in the section above.
from io import StringIO
from IPython.display import display_html, HTML
(D.1.) Plotly Express
import plotly.express as px
figure = px.sunburst(data, path=['Country', 'State', 'City'], values='Population')
figure.update_layout(margin=dict(t=10, b=10, r=10, l=10))
figure.show() # HTML(figure.to_html(include_plotlyjs='cdn'))
(D.2.) Graph Objects
import plotly.graph_objects as go
figure = go.Figure()

# Labels must be unique here, because parents refer to labels
figure.add_trace(go.Sunburst(
    labels=df_sum['id'],
    parents=df_sum['parent'],
    values=df_sum['value'],
    branchvalues='total',
    marker=dict(colorscale='Rdbu'),
    hovertemplate='<b> Country : %{label} </b> <br> Count : %{value} <extra>Population</extra>',
    maxdepth=2)  # show two levels at a time; click a sector to drill down
)
figure.update_layout(margin=dict(t=10, b=10, r=10, l=10))
figure.show() # HTML(figure.to_html(include_plotlyjs='cdn'))
(E) Communicating Plots with JSON
Plotly has a built-in method for saving a figure as JSON: write_json(). The following cells show how to write a figure to a file and regenerate the plot from it.
figure.write_json("../data/Sunburst_Chart.json")

import json
# Read the saved figure back; the with-block closes the file afterwards
with open("../data/Sunburst_Chart.json") as opened_file:
    opened_fig = json.load(opened_file)

fig_ = go.Figure(
    data = opened_fig['data'],
    layout = opened_fig['layout']
)
fig_.show() # HTML(fig_.to_html())
(F) Custom Plots
In this final section, we use go.Figure subplots to place both charts side by side and fully customize them.
from plotly.subplots import make_subplots
fig = make_subplots(1, 2, specs=[[{"type": "domain"}, {"type": "domain"}]],)

fig.add_trace(go.Sunburst(
    labels=df_sum['id'],
    parents=df_sum['parent'],
    values=df_sum['value'],
    branchvalues='total',
    marker=dict(colorscale='sunset'),
    hovertemplate='<b> Country : %{label} </b> <br> Count : %{value} <extra>Population</extra>',
    maxdepth=2), 1, 1)

fig.add_trace(go.Sunburst(
    labels=df_count['id'],
    parents=df_count['parent'],
    values=df_count['value'],
    branchvalues='total',
    marker=dict(colorscale='viridis'),
    hovertemplate='<b> Country : %{label} </b> <br> Count : %{value} <extra>Cities</extra>',
    maxdepth=2), 1, 2)

fig.update_layout(margin=dict(t=10, b=10, r=10, l=10))
fig.show() # HTML(fig.to_html())
Thank you!
References :