将Spark DataFrame列转换为python列表


103

我在具有两列mvv和count的数据帧上工作。

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

我想获得两个包含mvv值和计数值的列表。就像是

mvv = [1,2,3,4]
count = [5,9,3,1]

因此,我尝试了以下代码:第一行应返回python行列表。我想看第一个值:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

但是我在第二行收到一条错误消息:

AttributeError:getInt


从Spark 2.3开始,此代码是最快且最不可能引起OutOfMemory异常的代码:list(df.select('mvv').toPandas()['mvv'])Arrow已集成到PySpark中,从而toPandas大大加快了速度。如果您使用的是Spark 2.3+,请不要使用其他方法。请参阅我的答案以获取更多基准测试详细信息。
Powers

Answers:


140

明白了,为什么这种方式无法正常工作。首先,您尝试从类型获取整数,collect的输出如下所示:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

如果您采取这样的做法:

>>> firstvalue = mvv_list[0].mvv
Out: 1

您将获得mvv价值。如果您需要数组的所有信息,则可以采取以下方法:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

但是,如果对另一列尝试相同的操作,则会得到:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

发生这种情况是因为它count是一种内置方法。并且该列的名称与相同count。一种解决方法是将列名称更改count_count

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

但是不需要此解决方法,因为您可以使用字典语法访问列:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

它将最终成功!


它适用于第一列,但不适用于我认为的列计数,因为(spark的函数计数)
a.moussa

你能加上你在做什么吗?在评论中添加此处。
Thiago Baldim

谢谢您的答复,所以这一行工作mvv_list = [mvv_count.select('mvv')。collect()中的i的int(i.mvv)],但不是此count_list = [mvv_count中的i的int(i.i.count) .select('count')。collect()]返回无效的语法
a.moussa '16

不需要select('count')像这样添加此用法:count_list = [int(i.count) for i in mvv_list.collect()]我将示例添加到响应中。
Thiago Baldim

1
@ a.moussa的[i.['count'] for i in mvv_list.collect()]工作是明确使用名为“ count”的列而不是count函数
user989762

103

紧随其后的是一支衬垫,列出了您想要的清单。

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

3
在性能方面,此解决方案比您的解决方案要快得多。mvv_list = [int(i.mvv)for i in mvv_count.select('mvv')。collect()]
Chanaka Fernando

这是迄今为止我所见过的最好的解决方案。谢谢。
hui chen

22

这将为您提供所有元素作为列表。

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

1
这是Spark 2.3+最快,最有效的解决方案。请参阅我的答案中的基准测试结果。
Powers

15

以下代码将为您提供帮助

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
这应该是公认的答案。原因是您在整个过程中都停留在spark上下文中,然后在结束时进行收集,而不是较早地退出spark上下文,这可能会导致较大的收集,具体取决于您的操作。
AntiPawn79

15

根据我的数据,我得到了这些基准:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52秒

>>> [row[col] for row in data.collect()]

0.271秒

>>> list(data.select(col).toPandas()[col])

0.427秒

结果是一样的


1
如果您使用toLocalIterator代替,collect则应提高内存使用效率[row[col] for row in data.toLocalIterator()]
oglop,

5

如果出现以下错误:

AttributeError:“列表”对象没有属性“收集”

此代码将解决您的问题:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

我也收到该错误,此解决方案解决了该问题。但是为什么我得到了错误?(许多其他人似乎没有得到!)
bikashg

1

我进行了基准分析,这list(mvv_count_df.select('mvv').toPandas()['mvv'])是最快的方法。我很惊讶

我使用5个节点的i3.xlarge集群(每个节点具有30.5 GB的RAM和4个内核)和Spark 2.4.5对10万/亿个行数据集运行了不同的方法。数据以单列均匀分布在20个快速压缩的Parquet文件中。

这是基准测试结果(运行时间以秒为单位):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

在驱动程序节点上收集数据时要遵循的黄金法则:

  • 尝试用其他方法解决问题。将数据收集到驱动程序节点非常昂贵,无法利用Spark集群的功能,因此应尽可能避免。
  • 收集尽可能少的行。在收集数据之前,对列进行聚合,重复数据删除,过滤和修剪。尽可能少地将数据发送到驱动程序节点。

toPandas 在Spark 2.3中得到了显着改进。如果您使用的Spark版本早于2.3,则可能不是最佳方法。

有关更多详细信息/基准测试结果,请参见此处


1

可能的解决方案是使用中的collect_list()功能pyspark.sql.functions。这会将所有列值聚合到一个pyspark数组中,该数组在收集时将转换为python列表:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.