I'm trying to create a scrollable multiplot based on the answer to this question:
Creating a scrollable multiplot with python's pylab
Lines created using ax.plot() are updating correctly; however, I'm unable to figure out how to update artists created using vlines() and fill_between().
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider#create dataframes
# create dataframes: 100 frames, each with 30 rows
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Split 31 slots among the values 5, 8 and 7 in random proportions.
    # np.repeat needs integer repeat counts, so cast the rounded floats.
    # The rounded counts sum to 30..32, so the [:30] slice is always full.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# create figure, axis, subplot
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# plots
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')

# artists: one translucent band per col2 value, plus vertical stems for col3.
# These are drawn once from frame 0 and are not touched by update().
df0 = dfs[0]
for band_value, band_colour in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(df0.index, y1=df0['col2'] - 0.5, y2=df0['col2'] + 0.5,
                    where=df0['col2'] == band_value, facecolor=band_colour,
                    edgecolors='none', alpha=0.5)
ax.vlines(x=df0['col3'].index, ymin=0, ymax=df0['col3'], color='black')
def update(val):
    """Slider callback: show the frame selected by the slider.

    Only the two Line2D objects (ln1, ln2) and the title are refreshed. The
    fill_between bands and vlines drawn at start-up keep showing frame 0.

    Parameters
    ----------
    val : float
        Current slider value, as passed by Slider.on_changed.
    """
    # dfs is keyed by Python ints, while the slider reports a float.
    frame = int(np.floor(val))
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    fig.canvas.draw_idle()


# connect callback to slider
sframe.on_changed(update)
plt.show()
This is what it looks like at the moment
I can't apply the same approach as for plot(), since the following produces an error message:
# Fails with "TypeError: 'PolyCollection' object is not iterable":
# ax.plot() returns a list of lines, which is why "ln1," unpacking works for it.
# fill_between returns a single PolyCollection, so it cannot be unpacked this way.
ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5)
TypeError: 'PolyCollection' object is not iterable
This is what it's meant to look like on each frame
fill_between returns a PolyCollection, which expects a list (or several lists) of vertices upon creation. Unfortunately I haven't found a way to retrieve the vertices that were used to create the given PolyCollection, but in your case it is easy enough to create the PolyCollection directly (thereby avoiding the use of fill_between) and then update its vertices upon frame change.
Below is a version of your code that does what you are after:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.widgets import Slider
from matplotlib.collections import PolyCollection

# create dataframes: 100 frames, each with 30 rows
dfs = {}
for x in range(100):
    col1 = np.random.normal(10, 0.5, 30)
    # Split 31 slots among the values 5, 8 and 7 in random proportions.
    # np.repeat needs integer repeat counts, so cast the rounded floats.
    # The rounded counts sum to 30..32, so the [:30] slice is always full.
    counts = np.round(np.random.dirichlet(np.ones(3), size=1) * 31)[0].astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# create figure, axis, subplot
fig = plt.figure()
gs = gridspec.GridSpec(1, 1, hspace=0, wspace=0, left=0.1, bottom=0.1)
ax = plt.subplot(gs[0])
ax.set_ylim([0, 12])

# slider
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(axframe, 'frame', 0, 99, valinit=0, valfmt='%d')

# plots
ln1, = ax.plot(dfs[0].index, dfs[0]['col1'])
ln2, = ax.plot(dfs[0].index, dfs[0]['col2'], c='black')
# additional code to update the PolyCollections:
# the col2 value highlighted by the red, blue and green bands respectively
val_r = 5
val_b = 8
val_g = 7


def update_collection(collection, value, frame=0):
    """Redraw *collection* as the bands where ``dfs[frame]['col2'] == value``.

    Stands in for ``ax.fill_between(x, col2 - 0.5, col2 + 0.5,
    where=col2 == value)``. Every run of consecutive matching rows becomes
    one rectangle. It spans that run's first to last index horizontally and
    ``value - 0.5`` to ``value + 0.5`` vertically. If no row matches, the
    collection is emptied.

    Parameters
    ----------
    collection : matplotlib.collections.PolyCollection
        Collection whose polygons are replaced via ``set_verts``.
    value : number
        The col2 value to highlight.
    frame : int, optional
        Key into ``dfs`` selecting the frame to draw; defaults to 0.
    """
    xs = np.asarray(dfs[frame].index)
    ys = np.asarray(dfs[frame]['col2'])
    hits = np.flatnonzero(ys == value)
    if hits.size == 0:
        # No matching rows in this frame: clear the collection. An explicit
        # check replaces catching ValueError from min() on an empty array.
        collection.set_verts([np.zeros((0, 2))])
        return
    miny, maxy = value - 0.5, value + 0.5
    # Split the matching positions wherever they stop being consecutive, so a
    # gap between two runs is not bridged by one wide rectangle.
    runs = np.split(hits, np.flatnonzero(np.diff(hits) > 1) + 1)
    verts = []
    for run in runs:
        minx, maxx = xs[run[0]], xs[run[-1]]
        verts.append(np.array([[minx, miny], [maxx, miny],
                               [maxx, maxy], [minx, maxy]]))
    collection.set_verts(verts)
# artists
# Each PolyCollection replaces one of the original fill_between calls
# (same face colour, no edges, alpha 0.5). Each starts empty and is then
# filled with the frame-0 bands.
reds = PolyCollection([], facecolors=['r'], edgecolors='none', alpha=0.5)
ax.add_collection(reds)
update_collection(reds, val_r)

blues = PolyCollection([], facecolors=['b'], edgecolors='none', alpha=0.5)
ax.add_collection(blues)
update_collection(blues, val_b)

greens = PolyCollection([], facecolors=['g'], edgecolors='none', alpha=0.5)
ax.add_collection(greens)
update_collection(greens, val_g)

# Keep a handle on the LineCollection returned by vlines so that update()
# can replace its segments instead of leaving it stuck on frame 0.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'],
                  color='black')


def update(val):
    """Slider callback: redraw every artist for the selected frame.

    Refreshes both lines, the three coloured bands and the col3 stems.

    Parameters
    ----------
    val : float
        Current slider value, as passed by Slider.on_changed.
    """
    # dfs is keyed by Python ints, while the slider reports a float.
    frame = int(np.floor(val))
    ln1.set_ydata(dfs[frame]['col1'])
    ln2.set_ydata(dfs[frame]['col2'])
    ax.set_title('Frame ' + str(frame))
    # updating the PolyCollections
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)
    # one vertical segment (x, 0)-(x, col3) per row, as vlines drew initially
    col3 = dfs[frame]['col3']
    stems.set_segments([[(x, 0), (x, y)] for x, y in zip(col3.index, col3)])
    fig.canvas.draw_idle()


# connect callback to slider
sframe.on_changed(update)
plt.show()
Each of the three PolyCollections (reds, blues, and greens) has only four vertices (the corners of the rectangles), which are determined from the given data in update_collection. The result looks like this:
Tested in Python 3.5