Animating a network graph to show the progress of an algorithm

2024/9/22 23:24:30

I would like to animate a network graph to show the progress of an algorithm. I am using NetworkX for graph creation.

From this SO answer, I came up with a solution using clear_output from IPython.display and plt.pause() to control the speed of the animation. This works well for small graphs with a few nodes, but when I apply it to a 10x10 grid, the animation is very slow, and reducing the argument of plt.pause() does not seem to have any effect on the animation speed. Here is a minimal working example (MWE) with an implementation of Dijkstra's algorithm in which I update the node colors at each iteration of the algorithm:

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline

# plotting function
def get_fig(G, current, pred):
    nColorList = []
    for i in G.nodes():
        if i == current:
            nColorList.append('red')
        elif i == pred:
            nColorList.append('white')
        elif i == N:
            nColorList.append('grey')
        elif node_visited[i] == 1:
            nColorList.append('dodgerblue')
        else:
            nColorList.append('powderblue')
    plt.figure(figsize=(10, 10))
    nx.draw_networkx(G, pos, node_color=nColorList, width=2, node_size=400, font_size=10)
    plt.axis('off')
    plt.show()

# graph creation
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    x = i % 10
    y = math.floor(i / 10)
    pos[i] = (x, y)
    if i % 10 != 9 and i + 1 < 100:
        cost[(i, i + 1)] = random.randint(0, 9)
        cost[(i + 1, i)] = random.randint(0, 9)
    if i + 10 < 100:
        cost[(i, i + 10)] = random.randint(0, 9)
        cost[(i + 10, i)] = random.randint(0, 9)
G.add_edges_from(cost)

# algorithm initialization
lab = {}
path = {}
node_visited = {}
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
for i in G.nodes():
    if i == N:
        lab[i] = 0
    else:
        lab[i] = 9999
    path[i] = None
    node_visited[i] = 0

# algorithm main loop
while not SE.empty():
    (l, j) = SE.get()
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    for i in G.predecessors(j):
        insert_in_SE = 0
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        clear_output(wait=True)
        get_fig(G, j, i)
        plt.pause(0.0001)
print('end')

Ideally I would like to show the whole animation in no more than 5 seconds, whereas it currently takes a few minutes to complete the algorithm, which suggests that plt.pause(0.0001) does not work as intended.
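A rough frame budget helps explain why shrinking the pause changes little. In the loop above, get_fig is called once for every predecessor of every settled node, and every node of the grid is eventually settled, so the number of frames equals the number of directed edges. The snippet below rebuilds the same edge set (costs omitted) and computes how long each frame may take if the whole run should fit in 5 seconds. Each get_fig call creates a new figure and redraws all nodes, labels and edges, and plt.pause() only adds its interval on top of that drawing time, so the pause cannot make a frame cheaper than the redraw itself.

```python
# Rebuild the question's 10x10 directed grid (edge set only, costs omitted).
edges = set()
for i in range(100):
    if i % 10 != 9:          # neighbour to the right in the same row
        edges.add((i, i + 1))
        edges.add((i + 1, i))
    if i + 10 < 100:         # neighbour in the next row
        edges.add((i, i + 10))
        edges.add((i + 10, i))

# One frame per (settled node j, predecessor i) pair, i.e. one per directed edge.
n_frames = len(edges)
budget_s = 5.0
per_frame_ms = 1000 * budget_s / n_frames

print(n_frames)               # 360
print(round(per_frame_ms, 1)) # 13.9
```

At roughly 14 ms per frame, building and rendering a fresh 10x10 figure every time is unlikely to keep up, which is why drawing the graph once and only changing the colors of existing artists is the usual way to speed this up.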

After reading SO posts on graph animation (post 2 and post 3), it seems that matplotlib's animation module could solve this, but I have not been able to implement the answers in my algorithm. The answer in post 2 suggests using FuncAnimation from matplotlib, but I am struggling to adapt the update method to my problem; the answer in post 3 leads to a nice tutorial with a similar suggestion.

My question is: how can I speed up the animation for my problem? Is it possible to arrange the clear_output and plt.pause() calls for a faster animation, or should I use FuncAnimation from matplotlib? If the latter, how should I define the update function?
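For the FuncAnimation route, one possible update function is sketched below under a few assumptions that are not in the question: the Dijkstra loop is run first and each step is recorded as a frame (using heapq instead of queue.PriorityQueue, with costs stored as an edge attribute named `cost`), the nodes are drawn once as a single collection, and the update function only recolors that collection. The helper names `dijkstra_frames` and `node_colors` are hypothetical, the seed is arbitrary, and the Agg backend line is only there so the sketch runs headless.

```python
import heapq
import random

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove for interactive use
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.animation import FuncAnimation

# Same 10x10 directed grid as the question, with costs stored on the edges.
random.seed(0)
G = nx.DiGraph()
pos = {}
for i in range(100):
    pos[i] = (i % 10, i // 10)
    if i % 10 != 9:
        G.add_edge(i, i + 1, cost=random.randint(0, 9))
        G.add_edge(i + 1, i, cost=random.randint(0, 9))
    if i + 10 < 100:
        G.add_edge(i, i + 10, cost=random.randint(0, 9))
        G.add_edge(i + 10, i, cost=random.randint(0, 9))
N = random.randint(0, 99)


def dijkstra_frames(G, source):
    """Hypothetical helper: run the question's loop and record one frame
    (current node, predecessor, settled nodes) per predecessor examined."""
    lab = {n: float("inf") for n in G}
    lab[source] = 0
    visited = set()
    heap = [(0, source)]
    frames = []
    while heap:
        _, j = heapq.heappop(heap)
        if j in visited:
            continue
        visited.add(j)
        for i in G.predecessors(j):
            new_lab = G[i][j]["cost"] + lab[j]
            if lab[i] > new_lab:
                lab[i] = new_lab
                heapq.heappush(heap, (new_lab, i))
            frames.append((j, i, frozenset(visited)))
    return frames


nodes = list(G)  # fixed order shared by the collection and the color list


def node_colors(current, pred, visited):
    """Hypothetical helper: same color rules as the question's get_fig."""
    colors = []
    for n in nodes:
        if n == current:
            colors.append("red")
        elif n == pred:
            colors.append("white")
        elif n == N:
            colors.append("grey")
        elif n in visited:
            colors.append("dodgerblue")
        else:
            colors.append("powderblue")
    return colors


frames = dijkstra_frames(G, N)

# Draw everything once; only the node colors change between frames.
fig, ax = plt.subplots(figsize=(6, 6))
ax.axis("off")
nx.draw_networkx_edges(G, pos, ax=ax, arrows=False, width=1)
node_artist = nx.draw_networkx_nodes(G, pos, nodelist=nodes, ax=ax,
                                     node_size=200, node_color="powderblue")


def update(frame):
    current, pred, visited = frame
    node_artist.set_facecolor(node_colors(current, pred, visited))
    return (node_artist,)


# About 5 s in total: 5000 ms spread over all frames (interval is in ms).
anim = FuncAnimation(fig, update, frames=frames,
                     interval=max(1, 5000 // len(frames)), blit=True)
```

With %matplotlib inline only static images are shown, so in a notebook one option is to display the result with IPython.display.HTML(anim.to_jshtml()); with an interactive backend (and the Agg line removed), plt.show() plays it. Recording the frames first also keeps the drawing out of the algorithm's loop.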

Thank you for your help.

EDIT 1

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt

# plotting function
def get_fig(G, current, pred):
    for i in G.nodes():
        if i == current:
            G.nodes[i]['draw'].set_color('red')
        elif i == pred:
            G.nodes[i]['draw'].set_color('white')
        elif i == N:
            G.nodes[i]['draw'].set_color('grey')
        elif node_visited[i] == 1:
            G.nodes[i]['draw'].set_color('dodgerblue')
        else:
            G.nodes[i]['draw'].set_color('powderblue')

# graph creation
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    x = i % 10
    y = math.floor(i / 10)
    pos[i] = (x, y)
    if i % 10 != 9 and i + 1 < 100:
        cost[(i, i + 1)] = random.randint(0, 9)
        cost[(i + 1, i)] = random.randint(0, 9)
    if i + 10 < 100:
        cost[(i, i + 10)] = random.randint(0, 9)
        cost[(i + 10, i)] = random.randint(0, 9)
G.add_edges_from(cost)

# algorithm initialization
plt.figure(1, figsize=(10, 10))
lab = {}
path = {}
node_visited = {}
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
for i in G.nodes():
    if i == N:
        lab[i] = 0
    else:
        lab[i] = 9999
    path[i] = None
    node_visited[i] = 0
    G.nodes[i]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[i], node_size=400, alpha=1, node_color='powderblue')
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(i, j)], width=2)
plt.ion()
plt.draw()
plt.show()

# algorithm main loop
while not SE.empty():
    (l, j) = SE.get()
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    for i in G.predecessors(j):
        insert_in_SE = 0
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        get_fig(G, j, i)
        plt.draw()
        plt.pause(0.00001)
plt.close()

EDIT 2

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt

# graph creation
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    x = i % 10
    y = math.floor(i / 10)
    pos[i] = (x, y)
    if i % 10 != 9 and i + 1 < 100:
        cost[(i, i + 1)] = random.randint(0, 9)
        cost[(i + 1, i)] = random.randint(0, 9)
    if i + 10 < 100:
        cost[(i, i + 10)] = random.randint(0, 9)
        cost[(i + 10, i)] = random.randint(0, 9)
G.add_edges_from(cost)

# algorithm initialization
lab = {}
path = {}
node_visited = {}
N = random.randint(0, 99)
SE = queue.PriorityQueue()
SE.put((0, N))
cf = plt.figure(1, figsize=(10, 10))
ax = cf.add_axes((0, 0, 1, 1))
for i in G.nodes():
    if i == N:
        lab[i] = 0
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[i], node_size=400, alpha=1.0, node_color='grey')
    else:
        lab[i] = 9999
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[i], node_size=400, alpha=0.2, node_color='dodgerblue')
    path[i] = None
    node_visited[i] = 0
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(i, j)], width=3, alpha=0.2, arrows=False)
plt.ion()
plt.show()
ax = plt.gca()
canvas = ax.figure.canvas
background = canvas.copy_from_bbox(ax.bbox)

# algorithm main loop
while not SE.empty():
    (l, j) = SE.get()
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    if j != N:
        G.nodes[j]['draw'].set_color('r')
    for i in G.predecessors(j):
        insert_in_SE = 0
        if lab[i] > cost[(i, j)] + lab[j]:
            lab[i] = cost[(i, j)] + lab[j]
            path[i] = j
            SE.put((lab[i], i))
        if i != N:
            G.nodes[i]['draw'].set_alpha(0.7)
            G[i][j]['draw'].set_alpha(1.0)
        ax.draw_artist(G[i][j]['draw'])
        ax.draw_artist(G.nodes[i]['draw'])
        ax.draw_artist(G.nodes[j]['draw'])
        canvas.blit(ax.bbox)
        plt.pause(0.0001)
plt.close()
Answer

If your graph isn't too big, you could try the following approach, which sets the properties of individual nodes and edges. The trick is to save the output of the drawing functions, which gives you a handle to the object properties like color, transparency, and visibility.

import networkx as nx
import matplotlib.pyplot as plt

G = nx.cycle_graph(12)
pos = nx.spring_layout(G)

cf = plt.figure(1, figsize=(8, 8))
ax = cf.add_axes((0, 0, 1, 1))

# draw each node and edge separately and keep the returned artist as an attribute
for n in G:
    G.nodes[n]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[n], node_size=200, alpha=0.5, node_color='r')
for u, v in G.edges():
    G[u][v]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], alpha=0.5, arrows=False, width=5)
plt.ion()
plt.draw()

sp = nx.shortest_path(G, 0, 6)
edges = zip(sp[:-1], sp[1:])

# highlight the path one edge per second by changing the stored artists
for u, v in edges:
    plt.pause(1)
    G.nodes[u]['draw'].set_color('r')
    G.nodes[v]['draw'].set_color('r')
    G[u][v]['draw'].set_alpha(1.0)
    G[u][v]['draw'].set_color('r')
    plt.draw()
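The handles returned by the drawing functions are not all of the same type, which matters for the question's DiGraph. draw_networkx_nodes returns a matplotlib PathCollection and draw_networkx_edges returns a LineCollection when arrows=False, but for a directed graph with arrows enabled it returns a list of FancyArrowPatch objects, so calling set_alpha() directly on that handle fails. The three-node graph below is a made-up example for checking this, and the Agg backend is an assumption so it runs without a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for this check
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.collections import LineCollection, PathCollection

# Hypothetical tiny directed graph with fixed positions.
D = nx.DiGraph([(0, 1), (1, 2)])
pos = {0: (0, 0), 1: (1, 0), 2: (2, 0)}
plt.figure()

node_handle = nx.draw_networkx_nodes(D, pos, nodelist=[0])                      # one PathCollection
line_handle = nx.draw_networkx_edges(D, pos, edgelist=[(0, 1)], arrows=False)   # one LineCollection
arrow_handles = nx.draw_networkx_edges(D, pos, edgelist=[(1, 2)], arrows=True)  # list of FancyArrowPatch

# Property changes go through the handle; with arrows, loop over the patches.
node_handle.set_color("r")
line_handle.set_alpha(1.0)
for patch in arrow_handles:
    patch.set_alpha(1.0)
```

This is why keeping one edge handle per edge and calling set_alpha() on it, as the answer does, goes together with arrows=False.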

EDIT

Here is an example on a 10x10 grid using graphviz to do the layout. The whole thing runs in about 1 second on my machine.
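If pygraphviz is not installed, a regular grid does not strictly need a layout engine: nx.grid_2d_graph names each node by its (i, j) coordinates, so a position dictionary can be built from the node labels alone. This is an alternative to the graphviz layout used below, not what the answer itself does.

```python
import networkx as nx

G = nx.grid_2d_graph(10, 10)

# grid_2d_graph labels each node with its (i, j) coordinates,
# so the labels themselves can serve as drawing positions.
pos = {n: n for n in G}
```

The resulting pos can be passed to nx.draw_networkx_nodes and nx.draw_networkx_edges in place of the graphviz positions.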

import networkx as nx
import matplotlib.pyplot as plt

G = nx.grid_2d_graph(10, 10)
pos = nx.nx_agraph.graphviz_layout(G)  # requires pygraphviz
cf = plt.figure(1, figsize=(8, 8))
ax = cf.add_axes((0, 0, 1, 1))
for n in G:
    G.nodes[n]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[n], node_size=200, alpha=0.5, node_color='k')
for u, v in G.edges():
    G[u][v]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], alpha=0.5, arrows=False, width=5)
plt.ion()
plt.draw()
plt.show()
sp = nx.shortest_path(G, (0, 0), (9, 9))
edges = zip(sp[:-1], sp[1:])
# recolor only the nodes and edges on the shortest path
for u, v in edges:
    G.nodes[u]['draw'].set_color('r')
    G.nodes[v]['draw'].set_color('r')
    G[u][v]['draw'].set_alpha(1.0)
    G[u][v]['draw'].set_color('r')
    plt.draw()

EDIT 2

Here is another approach that is faster (it doesn't redraw the axes or all the nodes) and uses a breadth-first search algorithm. This one runs in about 2 seconds on my machine. I noticed that some backends are faster; I'm using GTKAgg.

import networkx as nx
import matplotlib.pyplot as plt

def single_source_shortest_path(G, source):
    ax = plt.gca()
    canvas = ax.figure.canvas
    background = canvas.copy_from_bbox(ax.bbox)
    level = 0                   # the current level
    nextlevel = {source: 1}     # nodes to check at the next level
    paths = {source: [source]}  # paths dictionary (paths to key from source)
    G.nodes[source]['draw'].set_color('r')
    G.nodes[source]['draw'].set_alpha(1.0)
    while nextlevel:
        thislevel = nextlevel
        nextlevel = {}
        for v in thislevel:
            # canvas.restore_region(background)
            s = G.nodes[v]['draw']
            s.set_color('r')
            s.set_alpha(1.0)
            for w in G[v]:
                if w not in paths:
                    # newly reached node: color it and its edge, then draw only those artists
                    n = G.nodes[w]['draw']
                    n.set_color('r')
                    n.set_alpha(1.0)
                    e = G[v][w]['draw']
                    e.set_alpha(1.0)
                    e.set_color('k')
                    ax.draw_artist(e)
                    ax.draw_artist(n)
                    ax.draw_artist(s)
                    paths[w] = paths[v] + [w]
                    nextlevel[w] = 1
            canvas.blit(ax.bbox)  # push the updated region to the screen
        level = level + 1
    return paths

if __name__ == '__main__':
    G = nx.grid_2d_graph(10, 10)
    pos = nx.nx_agraph.graphviz_layout(G)  # requires pygraphviz
    cf = plt.figure(1, figsize=(8, 8))
    ax = cf.add_axes((0, 0, 1, 1))
    for n in G:
        G.nodes[n]['draw'] = nx.draw_networkx_nodes(G, pos, nodelist=[n], node_size=200, alpha=0.2, node_color='k')
    for u, v in G.edges():
        G[u][v]['draw'] = nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], alpha=0.5, arrows=False, width=5)
    plt.ion()
    plt.show()
    path = single_source_shortest_path(G, source=(0, 0))
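The answer copies the background with copy_from_bbox but leaves the restore_region call commented out; that works there because each frame only paints additional red artists over the previous ones. When a frame has to undo something drawn earlier (for example, a node going from red back to another color, as in the question's get_fig), the usual matplotlib blitting cycle is: draw the static parts once, save the background, and for every frame restore the background, redraw only the changing artist, and blit. Below is a minimal sketch of that cycle with a single scatter collection standing in for the nodes; the grid positions and the color sequence are made up for illustration, and the Agg backend is only there so it runs without a window (on an interactive backend, canvas.blit is what pushes the region to the screen).

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless; use an interactive backend to watch it
import matplotlib.pyplot as plt

# Hypothetical data: 100 nodes on a 10x10 grid, recolored one at a time.
xs = [i % 10 for i in range(100)]
ys = [i // 10 for i in range(100)]

fig, ax = plt.subplots(figsize=(5, 5))
ax.set_xlim(-1, 10)
ax.set_ylim(-1, 10)
ax.axis("off")

# animated=True keeps the collection out of normal draws, so it is not baked into the background.
nodes = ax.scatter(xs, ys, s=100, c=["powderblue"] * 100, animated=True)

fig.canvas.draw()                                 # draw the static parts once
background = fig.canvas.copy_from_bbox(fig.bbox)  # and save them

colors = ["powderblue"] * 100
for k in range(100):
    colors[k] = "red"                             # current node
    if k > 0:
        colors[k - 1] = "dodgerblue"              # previous node is now settled
    nodes.set_facecolor(colors)
    fig.canvas.restore_region(background)         # erase the previous frame
    ax.draw_artist(nodes)                         # redraw only the changing artist
    fig.canvas.blit(fig.bbox)                     # push the updated region
    fig.canvas.flush_events()                     # let the GUI process the update
```

Because only one collection is redrawn per frame, the per-frame cost no longer grows with the number of separate node and edge artists.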
