Pandas side-by-side stacked bar plot

2024/10/15 17:19:30

I want to create a stacked bar plot of the Titanic dataset. The plot needs to group by "Pclass", "Sex" and "Survived". I have managed to do this with a lot of tedious numpy manipulation to produce the normalized plot below (where "M" is male and "F" is female).

[Figure: normalized stacked bar plot grouped by "Pclass" and "Sex"]

Is there a way to do this using pandas' built-in plotting functionality?

I have tried this:

import pandas as pd
import matplotlib.pyplot as plt
df = pd.read_csv('train.csv')
df_grouped = df.groupby(['Survived','Sex','Pclass'])['Survived'].count()
df_grouped.unstack().plot(kind='bar',stacked=True,  colormap='Blues', grid=True, figsize=(13,5));

Which is not what I want. Is there any way to produce the first plot using pandas plotting? Thanks in advance.

Answer

The resulting bars will not neighbour each other as in your first figure, but outside of that, pandas lets you do what you want as follows:

# Fraction survived and fraction died for every (Pclass, Sex) group
df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
df_g.plot.bar(stacked=True)

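To see why the stacked bars all reach the same height, it can help to look at the aggregated frame itself. The following is a minimal sketch on a small made-up frame standing in for train.csv (the rows are invented for illustration and are not taken from the Titanic data); it builds the same two columns and shows that each row sums to 1.

```python
import pandas as pd

# Tiny made-up stand-in for train.csv (illustrative values only)
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male', 'female',
                 'male', 'female', 'female', 'male', 'male'],
    'Survived': [1, 1, 0, 1, 1, 0, 1, 0, 0, 0],
})

# Survival rate per (Pclass, Sex) group, plus its complement
survived = df.groupby(['Pclass', 'Sex'])['Survived'].mean()
df_g = pd.DataFrame({'Survived': survived, 'Died': 1 - survived})

# Each row holds two fractions that add up to 1, so every stacked bar has height 1
row_totals = df_g.sum(axis=1)
print(df_g)
print(row_totals)
```

Because "Survived" is coded as 0/1, its group mean is already the normalized share, which is what removes the need for manual numpy normalization.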

Here, the horizontal grouping of patches is complicated by the requirement of stacking. If, for instance, we only cared about the value of "Survived", pandas could take care of it out-of-the-box.

df.groupby(['Pclass', 'Sex'])['Survived'].mean().unstack().plot.bar()

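As a rough illustration of why this case works out of the box: unstack() moves "Sex" into the columns, so plot.bar() draws one group per "Pclass" and one bar per sex. A sketch on made-up rows (again an invented stand-in, not the real data):

```python
import pandas as pd

# Made-up stand-in for train.csv (illustrative values only)
df = pd.DataFrame({
    'Pclass':   [1, 1, 1, 1, 2, 2, 3, 3, 3, 3],
    'Sex':      ['female', 'female', 'male', 'male', 'female',
                 'male', 'female', 'female', 'male', 'male'],
    'Survived': [1, 1, 0, 1, 1, 0, 1, 0, 0, 0],
})

# Rows: Pclass, columns: Sex, values: survival rate
wide = df.groupby(['Pclass', 'Sex'])['Survived'].mean().unstack()
print(wide)

# wide.plot.bar() would then place the 'female' and 'male' bars
# next to each other within each Pclass group.
```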

If an ad hoc solution suffices for post-processing the plot, doing so is also not terribly complicated:

import numpy as np

df_g = df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', lambda x: 1 - x.mean()])
df_g.columns = ['Survived', 'Died']
ax = df_g.plot.bar(stacked=True)

# Move back every second patch
for i in range(6):
    new_x = ax.patches[i].get_x() - (i % 2) / 2
    ax.patches[i].set_x(new_x)
    ax.patches[i + 6].set_x(new_x)

# Update tick locations correspondingly
minor_tick_locs = [x.get_x() + 1/4 for x in ax.patches[:6]]
major_tick_locs = np.array(minor_tick_locs).reshape(3, 2).mean(axis=1)
ax.set_xticks(minor_tick_locs, minor=True)
ax.set_xticks(major_tick_locs)

# Use indices from dataframe as tick labels
# (MultiIndex.labels was renamed to MultiIndex.codes in pandas 0.24)
minor_tick_labels = df_g.index.levels[1][df_g.index.codes[1]].values
major_tick_labels = df_g.index.levels[0].values
ax.xaxis.set_ticklabels(minor_tick_labels, minor=True)
ax.xaxis.set_ticklabels(major_tick_labels)

# Remove ticks and organize tick labels to avoid overlap
# (recent matplotlib expects a boolean here rather than the string 'off')
ax.tick_params(axis='x', which='both', bottom=False)
ax.tick_params(axis='x', which='minor', rotation=45)
ax.tick_params(axis='x', which='major', pad=35, rotation=0)

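The offsets 1/2 and 1/4 in the loop above follow from the bar geometry. Assuming pandas' default bar width of 0.5 and bars centred on 0, 1, ..., 5 (both are assumptions about the defaults, not something the code sets explicitly), the arithmetic can be checked without plotting anything:

```python
# Assumed defaults: bar width 0.5, bars centred on the integers 0..5
width = 0.5
n_bars = 6

# Left edge of each bar before any shifting
lefts = [i - width / 2 for i in range(n_bars)]

# Shift every second bar left by 1/2 so it touches its neighbour
shifted = [x - (i % 2) / 2 for i, x in enumerate(lefts)]

# Minor ticks sit at the bar centres (left edge + half the width = + 1/4)
minor_ticks = [x + width / 2 for x in shifted]

# Major ticks sit midway between the two bars of each pair
major_ticks = [(minor_ticks[i] + minor_ticks[i + 1]) / 2
               for i in range(0, n_bars, 2)]

print(shifted)
print(minor_ticks)
print(major_ticks)
```

If a different width is passed to plot.bar(), the constants in the loop would need to change accordingly.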
