Interactive Sunburst Charts - Plotly

code
analysis
Author

Pratik Kumar

Published

June 19, 2023

(A) Introduction

Data visualization plays a vital role in domains such as data analytics, data science, dashboarding, and exploratory/statistical analysis. Within the Python and R ecosystems, several visualization libraries are in common use.

Among these, one of the most widely used is the Plotly Graphing Library. It offers libraries in multiple languages, high-quality scientific and non-scientific graphs, and easily shareable interactive plots.

In this post, I will be discussing an intriguing plot called the Sunburst Chart. Sunburst charts provide an interactive visualization of layered information, allowing for an enhanced understanding of complex data structures.

(B) Sunburst Chart

A sunburst chart is a powerful visualization tool for representing hierarchical datasets. In a hierarchical dataset, the features or variables have parent-child relationships that form a tree-like structure. To generate a sunburst plot with Plotly, you can use either the plotly.express module or the plotly.graph_objects module.

Let’s consider an example dataframe (dummy data for demonstration purposes) with a tree-like structure, where the columns or features exhibit parent-child relationships with other columns.

General Dataset: This dataframe contains classes and values organized in columns, as depicted in the sample data provided.

Sunburst DataFrame: This hierarchical dataframe defines the logical parent-child relationships between columns and their corresponding values.

Now, let’s delve into how this data would appear by visualizing it using a sunburst chart.

(C) Datasets

The following dataset is dummy data for demonstration. You will often come across this kind of data while working on data science or analytics projects.

#Importing pandas to handle dataframe
import pandas as pd
# Suppress pandas warnings
import warnings
warnings.filterwarnings("ignore")
data = pd.read_csv("../data/dummy_data.csv")
data.head()
   Country  State  City  Population
0    India   INMP    A1         512
1    India   INCG    B2       12201
2    India   INCG    M1        9021
3      USA   USNY    C2         812
4      USA   USNY    N1         821
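If `../data/dummy_data.csv` is not at hand, you can rebuild the same 19 rows in memory. The sketch below is reconstructed from the hierarchical tables later in this post: each city's state and population, and each state's country. It is not the author's original file. The first five rows match the preview above, and the order of the remaining rows is an assumption.

```python
import pandas as pd

# (City, State, Population) triples recovered from the hierarchical sum table;
# the first five match data.head() above, the order of the rest is assumed
rows = [
    ("A1", "INMP", 512), ("B2", "INCG", 12201), ("M1", "INCG", 9021),
    ("C2", "USNY", 812), ("N1", "USNY", 821), ("D1", "INSD", 9104),
    ("E2", "INGD", 132), ("F1", "USSF", 82), ("G2", "INSA", 5121),
    ("H1", "INAS", 1232), ("I2", "USHF", 8841), ("J1", "INSR", 11),
    ("K2", "INCQ", 1236), ("L3", "USSF", 1200), ("O2", "USNY", 128),
    ("P1", "INSD", 20), ("Q1", "USXO", 4120), ("R1", "USXO", 60),
    ("S1", "INGD", 6012),
]
data = pd.DataFrame(rows, columns=["City", "State", "Population"])

# State codes starting with "IN" belong to India, "US" to the USA
data["Country"] = data["State"].str[:2].map({"IN": "India", "US": "USA"})

# Same column order as the CSV preview
data = data[["Country", "State", "City", "Population"]]
print(data.head())
```

The rest of the post works unchanged on this `data` frame.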

The dataset is not in hierarchical form. To draw it with graph objects, the sunburst chart needs a child (id), a parent, and a value for every node. The flat table therefore has to be converted into this 'chart-acceptable' format first, and the following function does the job. It is a modified version of the original function defined on the sunburst chart page of Plotly's Python documentation, which has more details.

def build_hierarchical_dataframe(df, levels, value_column, metric):
    """
    Build a hierarchy of levels for Sunburst.
    - Levels are given starting from the bottom to the top of the hierarchy,
    i.e. the last level corresponds to the root.
    - Input :
        - df : pandas dataframe
        - levels : list of column names in the order, child to root.
        - value_column : name of the column whose values are displayed in the chart.
        - metric : string value equal to "sum" or "count".
    - Output:
        - df_all_trees : pandas dataframe for sunburst with columns ['id', 'parent', 'value'].
    """
    trees = []

    for i, level in enumerate(levels):
        ## Group by this level and every level above it, aggregating only value_column
        grouped = df.groupby(levels[i:])[value_column]
        dfg = grouped.count() if metric == "count" else grouped.sum()
        dfg = dfg.reset_index()

        df_tree = pd.DataFrame({'id': dfg[level]})

        ## Set parent of the levels: the next level up, or 'Total' for the top level
        if i < len(levels) - 1:
            df_tree['parent'] = dfg[levels[i+1]]
        else:
            df_tree['parent'] = 'Total'

        df_tree['value'] = dfg[value_column]
        trees.append(df_tree)

    ## Value calculation for the root node 'Total'
    if metric == "count":
        total_value = df[value_column].count()
    else:
        total_value = df[value_column].sum()
    trees.append(pd.DataFrame([{'id': 'Total', 'parent': '', 'value': total_value}]))

    ## Stack all levels one below the other to form the final dataframe
    df_all_trees = pd.concat(trees, ignore_index=True)
    return df_all_trees

# Levels ordered from child (City) to root (Country)
levels = ['City', 'State', 'Country']
value_column = 'Population'
metric = "sum"

Hierarchical Sum dataframe

This dataframe represents the total population of every Country, State, and City under study.

df_sum=build_hierarchical_dataframe(data, levels, value_column, metric="sum")
df_sum
       id  parent  value
0      A1    INMP    512
1      B2    INCG  12201
2      C2    USNY    812
3      D1    INSD   9104
4      E2    INGD    132
5      F1    USSF     82
6      G2    INSA   5121
7      H1    INAS   1232
8      I2    USHF   8841
9      J1    INSR     11
10     K2    INCQ   1236
11     L3    USSF   1200
12     M1    INCG   9021
13     N1    USNY    821
14     O2    USNY    128
15     P1    INSD     20
16     Q1    USXO   4120
17     R1    USXO     60
18     S1    INGD   6012
19   INAS   India   1232
20   INCG   India  21222
21   INCQ   India   1236
22   INGD   India   6144
23   INMP   India    512
24   INSA   India   5121
25   INSD   India   9124
26   INSR   India     11
27   USHF     USA   8841
28   USNY     USA   1761
29   USSF     USA   1282
30   USXO     USA   4180
31  India   Total  44602
32    USA   Total  16064
33  Total          60666

Hierarchical Count dataframe

This dataframe represents the number of sub-classes (cities) contained in each Country and State under study.

df_count=build_hierarchical_dataframe(data, levels, value_column, metric="count")
df_count
       id  parent  value
0      A1    INMP      1
1      B2    INCG      1
2      C2    USNY      1
3      D1    INSD      1
4      E2    INGD      1
5      F1    USSF      1
6      G2    INSA      1
7      H1    INAS      1
8      I2    USHF      1
9      J1    INSR      1
10     K2    INCQ      1
11     L3    USSF      1
12     M1    INCG      1
13     N1    USNY      1
14     O2    USNY      1
15     P1    INSD      1
16     Q1    USXO      1
17     R1    USXO      1
18     S1    INGD      1
19   INAS   India      1
20   INCG   India      2
21   INCQ   India      1
22   INGD   India      2
23   INMP   India      1
24   INSA   India      1
25   INSD   India      2
26   INSR   India      1
27   USHF     USA      1
28   USNY     USA      3
29   USSF     USA      2
30   USXO     USA      2
31  India   Total     11
32    USA   Total      8
33  Total             19

(D) Visualizations

Next, we look at the two most common ways of plotting sunburst charts in Python. You can use either of the following modules:

  1. Plotly Express
  2. Plotly Graph Objects

Both modules produce the same kind of figure object, a plotly.graph_objects.Figure. The difference lies in the code syntax and in how flexibly the graph can be modified. Plotly Express generates a complete plot from a single function call with a predefined set of parameters. Graph objects make it easier to tweak individual details. Because Plotly Express returns a regular go.Figure, anything you can do with a figure built from graph objects can also be done with one generated by Plotly Express.

We will use both of them to plot the datasets created in the previous section.

# Only needed to render figures as HTML (the commented-out HTML(...) lines below)
from IPython.display import HTML

(D.1.) Plotly Express

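import plotly.express as px

# 'path' lists the hierarchy from root (Country) to leaves (City);
# Plotly Express builds the parent nodes and their totals itself
figure = px.sunburst(data, path=['Country', 'State', 'City'], values='Population')
figure.update_layout(margin=dict(t=10, b=10, r=10, l=10))
figure.show()
# HTML(figure.to_html(include_plotlyjs='cdn'))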

(D.2.) Graph Objects

import plotly.graph_objects as go

figure = go.Figure()
figure.add_trace(go.Sunburst(
    labels=df_sum['id'],
    parents=df_sum['parent'],
    values=df_sum['value'],
    # Each parent's value already includes the values of its children
    branchvalues='total',
    # A colorscale only takes effect when marker.colors is a numeric array
    marker=dict(colors=df_sum['value'], colorscale='RdBu'),
    hovertemplate='<b>%{label}</b><br>Population: %{value}<extra></extra>',
    # Limit how many levels are rendered at once; clicking a sector drills down
    maxdepth=2,
))
figure.update_layout(margin=dict(t=10, b=10, r=10, l=10))
figure.show()
# HTML(figure.to_html(include_plotlyjs='cdn'))
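Setting `branchvalues='total'` tells Plotly that every parent's value already covers its children. This holds for df_sum and df_count by construction. When a hierarchy is assembled by hand, a parent smaller than the sum of its children can leave the chart blank or missing sectors. The hypothetical helper below (`parents_cover_children`, not part of Plotly) is a minimal pandas check to run before plotting.

```python
import pandas as pd


def parents_cover_children(tree: pd.DataFrame) -> bool:
    """Hypothetical helper: True if every node's value is at least the sum of
    its direct children's values, as branchvalues='total' expects."""
    # Sum of the children's values, keyed by the parent they point to
    child_sums = tree.groupby("parent")["value"].sum()
    # Each node's own value, keyed by its id
    own_values = tree.set_index("id")["value"]
    # Only nodes that have children are checked; the root's '' parent key is skipped
    parents = child_sums.index.intersection(own_values.index)
    return bool((own_values[parents].to_numpy() >= child_sums[parents].to_numpy()).all())


# USNY branch of the hierarchical sum table
usa = pd.DataFrame({
    "id":     ["C2", "N1", "O2", "USNY", "USA"],
    "parent": ["USNY", "USNY", "USNY", "USA", ""],
    "value":  [812, 821, 128, 1761, 1761],
})
print(parents_cover_children(usa))

# Same branch with USNY understated: its children add up to more than 100
broken = usa.assign(value=[812, 821, 128, 100, 1761])
print(parents_cover_children(broken))
```

Calling it on df_sum or df_count directly works the same way, since both use the `id`, `parent` and `value` column names.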

(E) Communicating Plots with JSON

Plotly has a built-in method to save a figure as JSON: write_json(). The following cells show how to write the figure to a file and regenerate the plot from it.

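import json

# Save the figure (traces + layout) as a JSON file
figure.write_json("../data/Sunburst_Chart.json")

# Read the file back; the 'with' block closes the file handle automatically
with open("../data/Sunburst_Chart.json") as opened_file:
    opened_fig = json.load(opened_file)

# Rebuild the figure from the stored 'data' and 'layout' entries
fig_ = go.Figure(
    data=opened_fig['data'],
    layout=opened_fig['layout'],
)
fig_.show()
# HTML(fig_.to_html())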

(F) Custom Plots

In this final section, we use make_subplots to place both sunburst charts side by side in one go.Figure, where each trace can be fully customized.

from plotly.subplots import make_subplots

# Sunburst traces need "domain"-type cells instead of the default x/y axes
fig = make_subplots(rows=1, cols=2, specs=[[{"type": "domain"}, {"type": "domain"}]])

# Left: total population of every node
fig.add_trace(go.Sunburst(
    labels=df_sum['id'],
    parents=df_sum['parent'],
    values=df_sum['value'],
    branchvalues='total',
    marker=dict(colors=df_sum['value'], colorscale='sunset'),
    hovertemplate='<b>%{label}</b><br>Population: %{value}<extra></extra>',
    maxdepth=2), row=1, col=1)

# Right: number of cities under every node
fig.add_trace(go.Sunburst(
    labels=df_count['id'],
    parents=df_count['parent'],
    values=df_count['value'],
    branchvalues='total',
    marker=dict(colors=df_count['value'], colorscale='viridis'),
    hovertemplate='<b>%{label}</b><br>Cities: %{value}<extra></extra>',
    maxdepth=2), row=1, col=2)

fig.update_layout(margin=dict(t=10, b=10, r=10, l=10))
fig.show()
# HTML(fig.to_html())

Thank you!