I would like to animate a network graph to show the progress of an algorithm. I am using NetworkX for graph creation.
Based on this SO answer, I came up with a solution that uses clear_output from IPython.display and the command plt.pause() to manage the speed of the animation. This works well for small graphs with a few nodes, but when I apply it to a 10x10 grid the animation is very slow, and reducing the argument of plt.pause() does not seem to have any effect on the animation speed. Here is a minimal working example (MWE) with an implementation of Dijkstra's algorithm, where I update the colors of the nodes at each iteration of the algorithm:
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline# plotting function
def get_fig(G, current, pred):
    """Draw ``G`` in a brand-new 10x10 figure, colouring nodes by role.

    Colour priority (the first matching rule wins):
    ``current`` -> red, ``pred`` -> white, the start node ``N`` -> grey,
    nodes with ``node_visited[node] == 1`` -> dodgerblue, else powderblue.

    Reads the module-level globals ``N``, ``node_visited`` and ``pos``.
    Every call creates and fully renders a new figure, then shows it.
    """

    def colour_of(node):
        if node == current:
            return 'red'
        if node == pred:
            return 'white'
        if node == N:
            return 'grey'
        if node_visited[node] == 1:
            return 'dodgerblue'
        return 'powderblue'

    # Same order as G.nodes(), which is the order draw_networkx uses.
    node_colours = [colour_of(node) for node in G.nodes()]

    plt.figure(figsize=(10, 10))
    nx.draw_networkx(
        G, pos, node_color=node_colours, width=2, node_size=400, font_size=10
    )
    plt.axis('off')
    plt.show()


# graph creation
# Directed 10x10 grid: node i sits at column i % 10, row i // 10. Every pair
# of horizontally or vertically adjacent nodes is joined by two opposite arcs,
# each with its own random integer weight in [0, 9].
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    y, x = divmod(i, 10)
    pos[i] = (x, y)
    # Right neighbour (i + 1) first, then the neighbour in the next row
    # (i + 10); this fixes both the order in which random weights are drawn
    # and the order in which arcs are inserted into ``cost``.
    for step, has_neighbour in ((1, x < 9), (10, y < 9)):
        if has_neighbour:
            cost[(i, i + step)] = random.randint(0, 9)
            cost[(i + step, i)] = random.randint(0, 9)
# The keys of ``cost`` are (u, v) pairs, so the dict doubles as the edge list.
G.add_edges_from(cost)
# algorithm initialization
lab = {}           # lab[i]: tentative cost of the cheapest path from i to N
path = {}          # path[i]: successor of i on that path (None until found)
node_visited = {}  # 1 once node i has been settled by the main loop, else 0
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
for i in G.nodes():
    # math.inf instead of a magic 9999: the "not reached yet" label must be
    # larger than any real path cost, which 9999 stops guaranteeing as soon
    # as the grid or the weight range is enlarged.
    lab[i] = 0 if i == N else math.inf
    path[i] = None
    node_visited[i] = 0
# algorithm main loop
while not SE.empty():
    _, j = SE.get()
    # Stale queue entry: j was already settled with a smaller label.
    if node_visited[j]:
        continue
    node_visited[j] = 1
    # Relax every arc i -> j; lab[i] is the cost of reaching N from i.
    for i in G.predecessors(j):
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        # One frame per examined arc. get_fig builds and renders a brand-new
        # figure on every call, which is presumably what dominates the run
        # time; plt.pause only adds a tiny extra delay on top of it.
        # NOTE(review): the pasted original lost its indentation; the redraw
        # is assumed to run once per predecessor, as written here.
        clear_output(wait=True)
        get_fig(G, j, i)
        plt.pause(0.0001)
print('end')
Ideally I would like the whole animation to take no more than 5 seconds, whereas it currently takes a few minutes to complete the algorithm, which suggests that plt.pause(0.0001) does not work as intended.
After reading SO posts on graph animation (post 2 and post 3), it seems that the animation module from matplotlib could be used to resolve this, but I have not been able to implement the answers in my algorithm successfully. The answer in post 2 suggests using FuncAnimation from matplotlib, but I am struggling to adapt the update function to my problem; the answer in post 3 leads to a nice tutorial with a similar suggestion.
My question is: how can I improve the speed of the animation for my problem? Is it possible to arrange the clear_output and plt.pause() commands for a faster animation, or should I use FuncAnimation from matplotlib? If the latter, how should I define the update function?
Thank you for your help.
EDIT 1
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt# plotting function
def get_fig(G, current, pred):
    """Recolour the per-node artists stored under ``G.nodes[i]['draw']``.

    Nothing new is drawn: each node's existing matplotlib artist just gets a
    new colour. Priority (the first matching rule wins): ``current`` -> red,
    ``pred`` -> white, the start node ``N`` -> grey, nodes with
    ``node_visited[i] == 1`` -> dodgerblue, all others -> powderblue.

    Reads the module-level globals ``N`` and ``node_visited``.
    """
    for i in G.nodes():
        if i == current:
            colour = 'red'
        elif i == pred:
            colour = 'white'
        elif i == N:
            colour = 'grey'
        elif node_visited[i] == 1:
            colour = 'dodgerblue'
        else:
            colour = 'powderblue'
        # ``G.node`` was removed in NetworkX 2.4; ``G.nodes[i]`` is the
        # node-attribute accessor in NetworkX 2.x and 3.x.
        G.nodes[i]['draw'].set_color(colour)


# graph creation
# Directed 10x10 grid: node i sits at column i % 10, row i // 10. Every pair
# of horizontally or vertically adjacent nodes is joined by two opposite arcs,
# each with its own random integer weight in [0, 9].
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    y, x = divmod(i, 10)
    pos[i] = (x, y)
    # Right neighbour (i + 1) first, then the neighbour in the next row
    # (i + 10); this fixes both the order in which random weights are drawn
    # and the order in which arcs are inserted into ``cost``.
    for step, has_neighbour in ((1, x < 9), (10, y < 9)):
        if has_neighbour:
            cost[(i, i + step)] = random.randint(0, 9)
            cost[(i + step, i)] = random.randint(0, 9)
# The keys of ``cost`` are (u, v) pairs, so the dict doubles as the edge list.
G.add_edges_from(cost)
# algorithm initialization
plt.figure(1, figsize=(10, 10))
lab = {}           # lab[i]: tentative cost of the cheapest path from i to N
path = {}          # path[i]: successor of i on that path (None until found)
node_visited = {}  # 1 once node i has been settled by the main loop, else 0
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
for i in G.nodes():
    # math.inf instead of a magic 9999: the "not reached yet" label must be
    # larger than any real path cost, whatever the grid size or weights.
    lab[i] = 0 if i == N else math.inf
    path[i] = None
    node_visited[i] = 0
    # One artist per node so get_fig can recolour nodes individually.
    # ``G.node`` was removed in NetworkX 2.4, hence ``G.nodes[i]``.
    # ``with_labels`` is not a draw_networkx_nodes parameter (silently ignored
    # by old NetworkX, rejected by recent releases), so it is not passed.
    G.nodes[i]['draw'] = nx.draw_networkx_nodes(
        G, pos, nodelist=[i], node_size=400, alpha=1, node_color='powderblue'
    )
# Draw the node labels once, which is what with_labels=True was meant to do
# (same font size as the draw_networkx version of this script).
nx.draw_networkx_labels(G, pos, font_size=10)
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(i, j)], width=2)
plt.ion()
plt.draw()
plt.show()
# algorithm main loop
while not SE.empty():
    _, j = SE.get()
    # Stale queue entry: j was already settled with a smaller label.
    if node_visited[j]:
        continue
    node_visited[j] = 1
    # Relax every arc i -> j; lab[i] is the cost of reaching N from i.
    for i in G.predecessors(j):
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        # NOTE(review): the pasted original lost its indentation; the recolour
        # is assumed to run once per predecessor, as written here.
        get_fig(G, j, i)
        # plt.pause redraws the (now stale) active figure before pausing, so
        # a separate plt.draw() call is unnecessary.
        plt.pause(0.00001)
plt.close()
EDIT 2
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt# graph creation
# Directed 10x10 grid: node i sits at column i % 10, row i // 10. Every pair
# of horizontally or vertically adjacent nodes is joined by two opposite arcs,
# each with its own random integer weight in [0, 9].
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    y, x = divmod(i, 10)
    pos[i] = (x, y)
    # Right neighbour (i + 1) first, then the neighbour in the next row
    # (i + 10); this fixes both the order in which random weights are drawn
    # and the order in which arcs are inserted into ``cost``.
    for step, has_neighbour in ((1, x < 9), (10, y < 9)):
        if has_neighbour:
            cost[(i, i + step)] = random.randint(0, 9)
            cost[(i + step, i)] = random.randint(0, 9)
# The keys of ``cost`` are (u, v) pairs, so the dict doubles as the edge list.
G.add_edges_from(cost)
# algorithm initialization
lab = {}           # lab[i]: tentative cost of the cheapest path from i to N
path = {}          # path[i]: successor of i on that path (None until found)
node_visited = {}  # 1 once node i has been settled by the main loop, else 0
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
cf = plt.figure(1, figsize=(10, 10))
ax = cf.add_axes((0, 0, 1, 1))
for i in G.nodes():
    # One artist per node so the main loop can restyle nodes individually.
    # ``G.node`` was removed in NetworkX 2.4, hence ``G.nodes[i]``.
    if i == N:
        lab[i] = 0
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(
            G, pos, nodelist=[i], node_size=400, alpha=1.0,
            node_color='grey', ax=ax,
        )
    else:
        # math.inf instead of a magic 9999: the "not reached yet" label must
        # be larger than any real path cost, whatever the grid or weights.
        lab[i] = math.inf
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(
            G, pos, nodelist=[i], node_size=400, alpha=0.2,
            node_color='dodgerblue', ax=ax,
        )
    path[i] = None
    node_visited[i] = 0
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(
        G, pos, edgelist=[(i, j)], width=3, alpha=0.2, arrows=False, ax=ax
    )
plt.ion()
plt.show()
# Spin the event loop once so the figure is fully drawn before blitting:
# Axes.draw_artist needs the renderer cached by an initial full draw, and the
# background copy must come from a completely rendered canvas.
plt.pause(0.1)
canvas = cf.canvas
# Currently unused by the main loop; kept for a restore_region-based variant.
background = canvas.copy_from_bbox(ax.bbox)
# algorithm main loop
while not SE.empty():
    _, j = SE.get()
    # Stale queue entry: j was already settled with a smaller label.
    if node_visited[j]:
        continue
    node_visited[j] = 1
    # Highlight the node being settled (the start node N keeps its grey).
    # ``G.node`` was removed in NetworkX 2.4, hence ``G.nodes[...]``.
    if j != N:
        G.nodes[j]['draw'].set_color('r')
    # Relax every arc i -> j; lab[i] is the cost of reaching N from i.
    for i in G.predecessors(j):
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        # NOTE(review): the pasted original lost its indentation; assumed only
        # the alpha change is guarded by i != N and the rest runs per arc.
        if i != N:
            G.nodes[i]['draw'].set_alpha(0.7)
        G[i][j]['draw'].set_alpha(1.0)
        # Repaint just the changed artists onto the canvas and push that
        # region to the screen instead of redrawing the whole figure.
        ax.draw_artist(G[i][j]['draw'])
        ax.draw_artist(G.nodes[i]['draw'])
        ax.draw_artist(G.nodes[j]['draw'])
        canvas.blit(ax.bbox)
        # NOTE(review): plt.pause updates a stale active figure before pausing,
        # and the property changes above make it stale, so this probably
        # triggers a full redraw each frame and cancels most of the blitting
        # gain. The matplotlib blitting tutorial uses canvas.flush_events().
        plt.pause(0.0001)
plt.close()