Pandas / Pyplot中的散点图:如何按类别绘制


89

我正在尝试使用Pandas DataFrame对象在pyplot中制作一个简单的散点图,但想要一种有效的方式来绘制两个变量,但要用第三列(键)来指定符号。我已经尝试过使用df.groupby的各种方法,但是没有成功。下面是一个示例df脚本。这会根据“ key1”为标记着色,但是我想看到带有“ key1”类别的图例。我靠近吗?谢谢。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

Answers:


117

您可以scatter为此使用,但这需要您的数字key1,并且您不会注意到图例。

最好只plot对像这样的离散类别使用。例如:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

在此处输入图片说明

如果您希望外观看起来像默认pandas样式,则只需rcParams使用pandas样式表进行更新,并使用其颜色生成器即可。(我也略微调整了图例):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

在此处输入图片说明


为什么在上面的RGB示例中,符号在图例中显示了两次?如何只显示一次?
史蒂夫·舒利斯特

1
@SteveSchulist-ax.legend(numpoints=1)用于仅显示一个标记。与a一样Line2D,有两个,通常会有一条线连接两个标记。
Joe Kington 2015年

此代码仅在命令后添加plt.hold(True)后才对我ax.plot()有用。知道为什么吗?
Yuval Atzmon

set_color_cycle() 在matplotlib 1.5中已弃用。还有 set_prop_cycle(),现在。
ale

52

使用Seabornpip install seaborn)作为单行代码,这很简单

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1")

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

在此处输入图片说明

这是供参考的数据框:

在此处输入图片说明

由于您的数据中包含三个可变列,因此您可能需要使用以下命令绘制所有成对维度:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

在此处输入图片说明

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/是另一种选择。


19

使用plt.scatter,我只能想到一种:使用代理艺术家:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

结果是:

在此处输入图片说明


10

您可以使用df.plot.scatter,并将数组传递给定义每个点颜色的c =参数:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

在此处输入图片说明


4

您也可以尝试专注于声明式可视化的Altairggpot

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Altair代码

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

在此处输入图片说明

ggplot代码

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

在此处输入图片说明


3

从matplotlib 3.1开始,您可以使用.legend_elements()自动图例创建中显示了一个示例。优点是可以使用单个分散呼叫。

在这种情况下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

在此处输入图片说明

如果键不是直接以数字形式给出的,则看起来像

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

在此处输入图片说明


我收到一个错误消息,说“ PathCollection”对象没有属性“ legends_elements”。我的代码如下。fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel

1
@NandishPatel检查此答案的第一句。还要确保不要混淆legends_elementslegend_elements
ImportanceOfBeingErnest

是的,谢谢。那是一个错字(传说/传说)。自最近6个小时以来,我一直在从事某些工作,因此我没有出现过Matplotlib版本。我以为我正在使用最新的。我对文档说存在这种方法感到困惑,但是代码却给出了错误。再次感谢你。我现在可以睡觉了。
Nandish Patel


1

seaborn具有包装功能scatterplot,可以更高效地完成包装。

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.