matplotlib does not show legend in scatter plot

2024/9/8 10:42:24

I am trying to work on a clustering problem for which I need to plot a scatter plot for my clusters.

%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
df = pd.merge(dataframe,actual_cluster)
plt.scatter(df['x'], df['y'], c=df['cluster'])
plt.legend()
plt.show()

df['cluster'] is the actual cluster number. So I want that to be my color code.

It shows me a plot, but it does not show the legend. It does not raise an error either.

Am I doing something wrong?

Answer

EDIT:

Generating some random data:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.vq import kmeans2

n_clusters = 10
df = pd.DataFrame({'x': np.random.randn(1000), 'y': np.random.randn(1000)})
# kmeans2 returns (centroids, labels); keep only the per-row cluster labels
_, df['cluster'] = kmeans2(df, n_clusters)

Update

  • Use seaborn.relplot with kind='scatter' or use seaborn.scatterplot
    • Specify hue='cluster'
# figure level plot
sns.relplot(data=df, x='x', y='y', hue='cluster', palette='tab10', kind='scatter')

# axes level plot
fig, axes = plt.subplots(figsize=(6, 6))
sns.scatterplot(data=df, x='x', y='y', hue='cluster', palette='tab10', ax=axes)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))

Original Answer

Plotting (matplotlib v3.3.4):

fig, ax = plt.subplots(figsize=(8, 6))
cmap = plt.get_cmap('jet')  # plt.cm.get_cmap is no longer available in recent matplotlib releases
for i, cluster in df.groupby('cluster'):
    _ = ax.scatter(cluster['x'], cluster['y'], color=cmap(i/n_clusters), label=i, ec='k')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

Result:

Explanation:

Without going too deep into matplotlib internals: plotting one cluster at a time solves the issue. ax.scatter() returns a PathCollection object, which we explicitly throw away here, but which the Axes also keeps together with its label. ax.legend() builds its entries from those artist/label pairs. Plotting everything at once produces a single PathCollection, and with no label on it the legend has nothing to show; plotting one cluster at a time produces n_clusters PathCollection/label pairs. You can inspect those objects by calling ax.get_legend_handles_labels(), which returns something like:

([<matplotlib.collections.PathCollection at 0x7f60c2ff2ac8>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9d68>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9390>,
  <matplotlib.collections.PathCollection at 0x7f60c2f802e8>,
  <matplotlib.collections.PathCollection at 0x7f60c2f809b0>,
  <matplotlib.collections.PathCollection at 0x7f60c2ff9908>,
  <matplotlib.collections.PathCollection at 0x7f60c2f85668>,
  <matplotlib.collections.PathCollection at 0x7f60c2f8cc88>,
  <matplotlib.collections.PathCollection at 0x7f60c2f8c748>,
  <matplotlib.collections.PathCollection at 0x7f60c2f92d30>],
 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])

So actually ax.legend() is equivalent to ax.legend(*ax.get_legend_handles_labels()).
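
To make the difference concrete, here is a minimal sketch that compares what ax.get_legend_handles_labels() returns for a single unlabeled scatter call versus one labeled call per cluster. The data, the three made-up clusters, and the variable names are assumptions for illustration only; the Agg backend is selected just so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so this sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Made-up data: 30 points assigned to three illustrative clusters
rng = np.random.default_rng(0)
x, y = rng.normal(size=30), rng.normal(size=30)
clusters = np.arange(30) % 3

# One scatter call without a label: the legend has nothing to collect
fig1, ax1 = plt.subplots()
ax1.scatter(x, y, c=clusters)
handles_single, labels_single = ax1.get_legend_handles_labels()

# One scatter call per cluster, each with its own label
fig2, ax2 = plt.subplots()
for k in np.unique(clusters):
    mask = clusters == k
    ax2.scatter(x[mask], y[mask], label=str(k))
handles_multi, labels_multi = ax2.get_legend_handles_labels()

print(len(handles_single), labels_multi)
```

The first Axes yields no handles at all, which is why plt.legend() in the question draws nothing, while the second yields one handle per cluster.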

NOTES:

  1. If using Python 2, make sure i/n_clusters is a float

  2. Omitting fig, ax = plt.subplots() and using plt.<method> instead of ax.<method> works fine, but I always prefer to explicitly specify the Axes object I am using rather than implicitly use the "current axes" (plt.gca()).


OLD SIMPLE SOLUTION

In case you are ok with a colorbar (instead of discrete value labels), you can use Pandas built-in Matplotlib functionality:

df.plot.scatter('x', 'y', c='cluster', cmap='jet')
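
If you want discrete legend entries while keeping a single scatter call, as in the question's code, one alternative that is not part of the original answer is PathCollection.legend_elements(), available in matplotlib 3.1 and later. The sketch below uses made-up data with four illustrative clusters; the Agg backend is chosen only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so this sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up data: 60 points in four illustrative clusters
rng = np.random.default_rng(0)
df = pd.DataFrame({"x": rng.normal(size=60), "y": rng.normal(size=60)})
df["cluster"] = np.arange(60) % 4

fig, ax = plt.subplots()
sc = ax.scatter(df["x"], df["y"], c=df["cluster"], cmap="tab10")

# legend_elements() derives one handle/label pair per distinct color value
handles, labels = sc.legend_elements()
legend = ax.legend(handles, labels, title="cluster",
                   loc="center left", bbox_to_anchor=(1, 0.5))
```

This avoids the groupby loop, at the cost of relying on matplotlib to choose how the labels are formatted.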
